llm_router 0.1.0

A high-performance router and load balancer for LLM APIs like ChatGPT
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
// Root lib.rs file
// This re-exports the core library functionality

// Make modules public for library usage
pub mod types;
// Removed handler module declaration
// #[cfg(feature = "axum")]
// pub mod handler; // Keep handler optional

use crate::types::{InstanceStatus, LLMInstance, ModelCapability, RouterError, RoutingStrategy};
// use futures_util::future::join_all; // REMOVED - unused import
// use rand::prelude::*; // Removed unused import
use std::{
    collections::HashMap,
    sync::{
        atomic::{AtomicBool, AtomicUsize, Ordering},
        Arc,
    },
    time::{Duration, Instant},
};
use tokio::sync::{oneshot, RwLock}; // Use RwLock for instances list
use tracing::{debug, error, info, instrument, warn}; // Use tracing macros
use futures::{stream, StreamExt}; // Import stream utilities

// Re-export core types for convenience
// Commented out to avoid duplicate re-exports
// pub use crate::types::{RouterError, RoutingStrategy};
// No need to re-export ModelCapability since it's already imported and used

/// Default period between two consecutive health check cycles.
const DEFAULT_HEALTH_CHECK_INTERVAL: Duration = Duration::from_secs(5);
/// Default request timeout configured on the dedicated health check HTTP client.
const DEFAULT_HEALTH_CHECK_TIMEOUT: Duration = Duration::from_secs(2);
/// Default path appended to an instance's base URL when probing its health.
const DEFAULT_HEALTH_CHECK_PATH: &str = "/health"; // Common health check endpoint
/// Default length of the exclusion window applied by `Router::timeout_instance`.
const DEFAULT_INSTANCE_TIMEOUT: Duration = Duration::from_secs(30); // Default timeout duration

// --- Router Core ---

/// State shared (behind an `Arc`) between every `Router` handle and the
/// background health check task.
#[derive(Debug)]
struct RouterInternalState {
    /// Registered backend instances; guarded by an async `RwLock`.
    instances: RwLock<Vec<LLMInstance>>,
    /// Strategy applied when choosing among the available instances.
    strategy: RoutingStrategy,
    /// Monotonically increasing counter used to rotate through candidates.
    next_instance_index: AtomicUsize, // For RoundRobin
    /// Main HTTP client; stored here but not read by the code in this file.
    _http_client: reqwest::Client, // Prefix with _ to silence warning for now
    /// Length of the exclusion window set by `Router::timeout_instance`.
    instance_timeout_duration: Duration, // How long to timeout an instance
}

/// Handle to the router. Cloning produces another handle to the same shared
/// state: both fields are `Arc`s, so no instance data is copied.
#[derive(Debug, Clone)] // Clone is cheap due to Arc
pub struct Router {
    /// Shared routing state (instance list, strategy, counters).
    internal: Arc<RouterInternalState>,
    // Keep handle to stop health check task on drop
    /// One-shot sender taken by the `Drop` impl to signal the background
    /// health check task to stop; `None` once the signal has been sent.
    stop_health_check_tx: Arc<RwLock<Option<oneshot::Sender<()>>>>,
}

// --- Model-Instance mapping type ---
/// Describes one model served by an instance together with the capabilities
/// the instance offers for that model.
#[derive(Debug, Clone)]
pub struct ModelInstanceConfig {
    /// Name of the model; used as the key in the instance's supported-models map.
    pub model_name: String,
    /// Capabilities available for `model_name` on the instance.
    pub capabilities: Vec<ModelCapability>,
}

// --- Builder Pattern ---

/// Collects configuration for a `Router`; consumed by `RouterBuilder::build`.
#[derive(Debug)]
pub struct RouterBuilder {
    /// Instances the router starts with.
    initial_instances: Vec<LLMInstance>,
    /// Strategy used to pick among available instances.
    strategy: RoutingStrategy,
    /// Period between health check cycles.
    health_check_interval: Duration,
    /// Request timeout configured on the dedicated health check HTTP client.
    health_check_timeout: Duration,
    /// Path probed on each instance; the setter ensures it starts with `/`.
    health_check_path: String,
    /// Length of the exclusion window applied when an instance is timed out.
    instance_timeout_duration: Duration,
    /// Optional caller-supplied main HTTP client; a default one is built when `None`.
    reqwest_client: Option<reqwest::Client>,
}

impl Default for RouterBuilder {
    /// Produces a builder with no instances, no custom HTTP client, the
    /// `LoadBased` strategy and the module-level `DEFAULT_*` settings.
    fn default() -> Self {
        Self {
            strategy: RoutingStrategy::LoadBased,
            initial_instances: Vec::new(),
            reqwest_client: None,
            health_check_path: String::from(DEFAULT_HEALTH_CHECK_PATH),
            health_check_interval: DEFAULT_HEALTH_CHECK_INTERVAL,
            health_check_timeout: DEFAULT_HEALTH_CHECK_TIMEOUT,
            instance_timeout_duration: DEFAULT_INSTANCE_TIMEOUT,
        }
    }
}

impl RouterBuilder {
    /// Creates a builder populated with the default configuration.
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the whole list of initial instances.
    pub fn initial_instances(self, instances: Vec<LLMInstance>) -> Self {
        Self {
            initial_instances: instances,
            ..self
        }
    }

    /// Sets the routing strategy.
    pub fn strategy(self, strategy: RoutingStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Sets the period between health check cycles.
    pub fn health_check_interval(self, interval: Duration) -> Self {
        Self {
            health_check_interval: interval,
            ..self
        }
    }

    /// Sets the request timeout used by the dedicated health check client.
    pub fn health_check_timeout(self, timeout: Duration) -> Self {
        Self {
            health_check_timeout: timeout,
            ..self
        }
    }

    /// Sets the health check path, prepending `/` when the caller omitted it.
    pub fn health_check_path(mut self, path: impl Into<String>) -> Self {
        let requested: String = path.into();
        self.health_check_path = if requested.starts_with('/') {
            requested
        } else {
            format!("/{}", requested)
        };
        self
    }

    /// Sets how long a timed-out instance stays excluded from selection.
    pub fn instance_timeout_duration(self, duration: Duration) -> Self {
        Self {
            instance_timeout_duration: duration,
            ..self
        }
    }

    /// Appends an instance together with its model/capability configuration.
    ///
    /// The instance starts with status `InstanceStatus::Unknown`. When the same
    /// model name appears more than once in `models`, the last entry wins.
    pub fn instance_with_models(
        mut self,
        id: impl Into<String>,
        base_url: impl Into<String>,
        models: Vec<ModelInstanceConfig>,
    ) -> Self {
        let supported: HashMap<_, _> = models
            .into_iter()
            .map(|entry| (entry.model_name, entry.capabilities))
            .collect();

        self.initial_instances.push(LLMInstance::new(
            id.into(),
            base_url.into(),
            InstanceStatus::Unknown, // status is unknown until a health check runs
            supported,
        ));
        self
    }

    /// Appends an instance without any model configuration
    /// (kept for backward compatibility).
    pub fn instance(self, id: impl Into<String>, base_url: impl Into<String>) -> Self {
        let no_models: Vec<ModelInstanceConfig> = Vec::new();
        self.instance_with_models(id, base_url, no_models)
    }

    /// Appends several `(id, base_url)` instances without model configuration
    /// (kept for backward compatibility).
    pub fn instances(self, instances: Vec<(String, String)>) -> Self {
        instances
            .into_iter()
            .fold(self, |builder, (id, base_url)| builder.instance(id, base_url))
    }

    /// Supplies a custom reqwest client (e.g., with specific TLS config) to be
    /// stored as the router's main client.
    pub fn http_client(self, client: reqwest::Client) -> Self {
        Self {
            reqwest_client: Some(client),
            ..self
        }
    }

    /// Builds the `Router` and starts the background health checking task.
    ///
    /// The health check task is started through `HealthCheck::spawn`, which
    /// calls `tokio::spawn`, so this must run inside a Tokio runtime.
    ///
    /// # Panics
    ///
    /// Panics if either reqwest client cannot be constructed.
    pub fn build(self) -> Router {
        let Self {
            initial_instances,
            strategy,
            health_check_interval,
            health_check_timeout,
            health_check_path,
            instance_timeout_duration,
            reqwest_client,
        } = self;

        // Fall back to a client with a generous total timeout when none was supplied.
        let main_client = match reqwest_client {
            Some(client) => client,
            None => reqwest::Client::builder()
                .timeout(Duration::from_secs(150))
                .build()
                .expect("Failed to build default reqwest client"),
        };

        // Health checks get their own client so the short probe timeout never
        // applies to the main client.
        let health_check_client = reqwest::Client::builder()
            .timeout(health_check_timeout)
            .build()
            .expect("Failed to build health check reqwest client");

        let shared_state = Arc::new(RouterInternalState {
            instances: RwLock::new(initial_instances),
            strategy,
            next_instance_index: AtomicUsize::new(0),
            _http_client: main_client,
            instance_timeout_duration,
        });

        let (stop_tx, stop_rx) = oneshot::channel::<()>();

        let router = Router {
            internal: Arc::clone(&shared_state),
            stop_health_check_tx: Arc::new(RwLock::new(Some(stop_tx))),
        };

        HealthCheck::spawn(
            shared_state,
            stop_rx,
            health_check_client,
            health_check_interval,
            health_check_timeout,
            health_check_path,
        );

        router
    }
}

// --- Router Implementation ---

impl Router {
    /// Entry point for configuring a `Router`: hands back a `RouterBuilder`
    /// in its default configuration.
    pub fn builder() -> RouterBuilder {
        let fresh_builder = RouterBuilder::new();
        fresh_builder
    }

    /// Returns a cloned snapshot of the currently configured instances.
    ///
    /// The instance list lock is held only for the duration of the copy.
    pub async fn get_instances(&self) -> Vec<LLMInstance> {
        let guard = self.internal.instances.read().await;
        guard.to_vec()
    }

    /// Adds a new LLM instance to the router.
    /// If an instance with the same ID already exists, it returns an error.
    #[instrument(skip(self, id, base_url))]
    pub async fn add_instance(
        &self,
        id: impl Into<String>,
        base_url: impl Into<String>,
    ) -> Result<(), RouterError> {
        self.add_instance_with_models(id, base_url, Vec::new()).await
    }

    /// Adds a new LLM instance with model capabilities to the router.
    #[instrument(skip(self, id, base_url, models))]
    pub async fn add_instance_with_models(
        &self,
        id: impl Into<String>,
        base_url: impl Into<String>,
        models: Vec<ModelInstanceConfig>,
    ) -> Result<(), RouterError> {
        let instance_id = id.into();
        let base_url_str = base_url.into();

        // Validate URL format
        let _ = url::Url::parse(&base_url_str).map_err(RouterError::InvalidUrl)?;

        let mut instances = self.internal.instances.write().await;

        // Check for duplicate ID
        if instances.iter().any(|inst| inst.id == instance_id) {
            warn!(instance_id = %instance_id, "Attempted to add duplicate instance");
            return Err(RouterError::InstanceExists(instance_id));
        }

        // Create the instance
        let mut instance = LLMInstance {
            id: instance_id.clone(),
            base_url: base_url_str,
            active_requests: Arc::new(AtomicUsize::new(0)),
            status: Arc::new(RwLock::new(InstanceStatus::Unknown)), // Start as Unknown
            is_in_timeout: Arc::new(AtomicBool::new(false)),
            timeout_until: Arc::new(RwLock::new(None)),
            supported_models: Arc::new(RwLock::new(HashMap::new())), // Initialize empty
        };

        // Set up the supported models map
        let mut supported_models_map = HashMap::new();
        for config in models {
            supported_models_map.insert(config.model_name.clone(), config.capabilities.clone());
            info!(
                instance_id = %instance.id, 
                model = %config.model_name, 
                capabilities = ?config.capabilities, 
                "Adding model support for new instance"
            );
        }

        // Update the instance with model support
        let supported_models_guard = Arc::get_mut(&mut instance.supported_models)
            .expect("Failed to get mutable reference to supported_models");
        *supported_models_guard = RwLock::new(supported_models_map);

        instances.push(instance);
        info!(instance_id = %instance_id, "Added new instance");
        Ok(())
    }

    /// Inserts or replaces the capability list of `model_name` on an
    /// already-registered instance.
    ///
    /// # Errors
    ///
    /// Returns `RouterError::InstanceNotFound` when no instance has the given ID.
    #[instrument(skip(self, instance_id, model_name, capabilities))]
    pub async fn add_model_to_instance(
        &self,
        instance_id: &str,
        model_name: String,
        capabilities: Vec<ModelCapability>,
    ) -> Result<(), RouterError> {
        let instances = self.internal.instances.read().await;
        let target = instances.iter().find(|inst| inst.id == instance_id);

        match target {
            None => {
                error!(instance_id = %instance_id, "Instance not found when trying to add model");
                Err(RouterError::InstanceNotFound(instance_id.to_string()))
            }
            Some(instance) => {
                // Only the per-instance model map is write-locked; the
                // instance list itself stays under the read guard.
                let mut model_map = instance.supported_models.write().await;
                info!(
                    instance_id = %instance_id,
                    model = %model_name,
                    capabilities = ?capabilities,
                    "Adding/Updating model support"
                );
                model_map.insert(model_name, capabilities);
                Ok(())
            }
        }
    }

    /// Drops every registered instance whose ID equals `id`.
    ///
    /// # Errors
    ///
    /// Returns `RouterError::InstanceNotFound` when nothing was removed.
    #[instrument(skip(self, id))]
    pub async fn remove_instance(&self, id: &str) -> Result<(), RouterError> {
        let mut instances = self.internal.instances.write().await;
        let count_before = instances.len();
        instances.retain(|inst| inst.id != id);

        // `retain` can only shrink the list, so an unchanged length means no match.
        let nothing_removed = instances.len() == count_before;
        if nothing_removed {
            error!(instance_id = %id, "Instance not found for removal");
            return Err(RouterError::InstanceNotFound(id.to_string()));
        }

        info!(instance_id = %id, "Removed instance");
        Ok(())
    }

    /// Selects the next appropriate backend instance based on the configured strategy.
    /// Takes into account model requirements and instance health.
    #[instrument(skip(self, model_name, capability))]
    pub async fn select_instance_for_model(
        &self,
        model_name: &str,
        capability: ModelCapability,
    ) -> Result<LLMInstance, RouterError> {
        let instance_refs: Vec<LLMInstance> = {
            self.internal.instances.read().await.clone()
        };

        // Collect references to available instances first
        let mut available_refs = Vec::new();
        for instance in instance_refs.iter() {
            let status = instance.status.read().await.clone();
            let is_timed_out = instance.is_in_timeout.load(Ordering::SeqCst);
            let timeout_expiry = *instance.timeout_until.read().await;
            let now = Instant::now();

            if is_timed_out {
                if let Some(expiry) = timeout_expiry {
                    if now < expiry {
                        continue;
                    }
                }
            }

            if status != InstanceStatus::Healthy {
                continue;
            }

            let supports = self
                .instance_supports_model(&instance.id, model_name, &capability)
                .await?;
            if supports {
                available_refs.push(instance); // Push reference
            } else {
                debug!(instance_id = %instance.id, model = %model_name, capability = ?capability, "Instance does not support model/capability.");
            }
        }

        if available_refs.is_empty() {
            warn!(model = %model_name, capability = ?capability, "No healthy instances available for the required model and capability.");
            return Err(RouterError::NoHealthyInstancesForModel(
                model_name.to_string(),
                capability,
            ));
        }

        // Apply routing strategy using references, clone only the final result
        match self.internal.strategy {
            RoutingStrategy::RoundRobin => {
                let index = self
                    .internal
                    .next_instance_index
                    .fetch_add(1, Ordering::SeqCst);
                let selected_instance = available_refs[index % available_refs.len()].clone(); // Clone here
                debug!(instance_id = %selected_instance.id, strategy = "RoundRobin", "Selected instance");
                Ok(selected_instance)
            }
            RoutingStrategy::LoadBased => {
                available_refs
                    .get(0) // -> Option<&&LLMInstance>
                    .map(|instance_ref_ref| (*instance_ref_ref).clone()) // -> Option<LLMInstance>
                    .ok_or_else(|| {
                        error!("LoadBased selection failed: available_refs was non-empty but get(0) yielded None.");
                        RouterError::NoHealthyInstancesForModel(model_name.to_string(), capability)
                    })
            }
        }
    }

    /// Selects the next backend instance based only on health and strategy, ignoring models.
    #[instrument(skip(self))]
    pub async fn select_next_instance(&self) -> Result<LLMInstance, RouterError> {
         let instance_refs: Vec<LLMInstance> = {
            self.internal.instances.read().await.clone()
        };

        // Collect references to available instances first
        let mut available_refs = Vec::new();
        for instance in instance_refs.iter() {
            let status = instance.status.read().await.clone();
            let is_timed_out = instance.is_in_timeout.load(Ordering::SeqCst);
            let timeout_expiry = *instance.timeout_until.read().await;
            let now = Instant::now();

            if is_timed_out {
                if let Some(expiry) = timeout_expiry {
                    if now < expiry {
                        continue;
                    }
                }
            }

            if status == InstanceStatus::Healthy {
                available_refs.push(instance); // Push reference
            } else {
                debug!(instance_id = %instance.id, status = ?status, "Instance is not healthy.");
            }
        }

        if available_refs.is_empty() {
            warn!("No healthy instances available for selection.");
            return Err(RouterError::NoHealthyInstances);
        }

        // Apply routing strategy using references, clone only the final result
        match self.internal.strategy {
            RoutingStrategy::RoundRobin => {
                let index = self
                    .internal
                    .next_instance_index
                    .fetch_add(1, Ordering::SeqCst);
                 let selected_instance = available_refs[index % available_refs.len()].clone(); // Clone here
                debug!(instance_id = %selected_instance.id, strategy = "RoundRobin", "Selected instance");
                Ok(selected_instance)
            }
            RoutingStrategy::LoadBased => {
                 available_refs
                    .get(0)
                    .map(|instance_ref_ref| (*instance_ref_ref).clone())
                    .ok_or_else(|| {
                        error!("LoadBased selection failed: available_refs was non-empty but get(0) yielded None.");
                        RouterError::NoHealthyInstances
                    })
            }
        }
    }

    /// Adds one to the active request counter of the instance with the given ID.
    ///
    /// # Errors
    ///
    /// Returns `RouterError::InstanceNotFound` when no instance has that ID.
    pub async fn increment_request_count(&self, instance_id: &str) -> Result<(), RouterError> {
        let instances = self.internal.instances.read().await;
        instances
            .iter()
            .find(|inst| inst.id == instance_id)
            .map(|inst| {
                inst.active_requests.fetch_add(1, Ordering::SeqCst);
            })
            .ok_or_else(|| RouterError::InstanceNotFound(instance_id.to_string()))
    }

    /// Decrements the active request count for a given instance.
    pub async fn decrement_request_count(&self, instance_id: &str) -> Result<(), RouterError> {
        let instances = self.internal.instances.read().await;
        if let Some(instance) = instances.iter().find(|inst| inst.id == instance_id) {
             instance.active_requests.fetch_sub(1, Ordering::SeqCst);
             Ok(())
        } else {
            Err(RouterError::InstanceNotFound(instance_id.to_string()))
        }
    }

    /// Puts an instance into timeout for the router's configured
    /// `instance_timeout_duration`.
    ///
    /// Sets the timeout flag, records the expiry instant and switches the
    /// status to `InstanceStatus::TimedOut`, in that order.
    ///
    /// # Errors
    ///
    /// Returns `RouterError::InstanceNotFound` when no instance has that ID.
    pub async fn timeout_instance(&self, instance_id: &str) -> Result<(), RouterError> {
        let instances = self.internal.instances.read().await;
        let instance = match instances.iter().find(|inst| inst.id == instance_id) {
            Some(found) => found,
            None => {
                error!(instance_id = %instance_id, "Instance not found for timeout.");
                return Err(RouterError::InstanceNotFound(instance_id.to_string()));
            }
        };

        let timeout_duration = self.internal.instance_timeout_duration;
        let expires_at = Instant::now() + timeout_duration;

        instance.is_in_timeout.store(true, Ordering::SeqCst);
        *instance.timeout_until.write().await = Some(expires_at);
        // The status mirrors the timeout so health-based filters skip the instance too.
        *instance.status.write().await = InstanceStatus::TimedOut;

        warn!(instance_id = %instance_id, duration = ?timeout_duration, "Instance timed out.");
        Ok(())
    }

    /// Reports whether the instance with `instance_id` lists `capability` for
    /// `model_name`.
    ///
    /// Yields `Ok(false)` when the model is not listed for the instance at all.
    ///
    /// # Errors
    ///
    /// Returns `RouterError::InstanceNotFound` when no instance has that ID.
    async fn instance_supports_model(
        &self,
        instance_id: &str,
        model_name: &str,
        capability: &ModelCapability,
    ) -> Result<bool, RouterError> {
        let instances = self.internal.instances.read().await;
        let instance = match instances.iter().find(|inst| inst.id == instance_id) {
            Some(found) => found,
            None => return Err(RouterError::InstanceNotFound(instance_id.to_string())),
        };

        let model_map = instance.supported_models.read().await;
        let supported = model_map
            .get(model_name)
            .map_or(false, |caps| caps.contains(capability));
        Ok(supported)
    }
}

// RAII guard for request counting
/// Guard that increments an instance's active request count when created via
/// `RequestTracker::new` and schedules the matching decrement when dropped.
#[derive(Debug)]
pub struct RequestTracker {
    /// Router handle used to adjust the counter.
    router: Router, // Clone the router (cheap due to Arc)
    /// ID of the instance whose counter this guard tracks.
    instance_id: String,
    // Flag to indicate if increment happened, to avoid decrementing if creation failed.
    incremented: bool,
}

impl RequestTracker {
    // Needs to be async now
    pub async fn new(router: Router, instance_id: String) -> Self {
        // Increment count on creation
        let increment_result = router.increment_request_count(&instance_id).await;
        let incremented = match increment_result {
            Ok(_) => true,
            Err(e) => {
                 error!(instance_id = %instance_id, error = ?e, "Failed to increment request count in RequestTracker");
                 false
            }
        };
        RequestTracker { router, instance_id, incremented }
    }
}

impl Drop for RequestTracker {
    fn drop(&mut self) {
        // Only decrement if increment succeeded
        if self.incremented {
            // Decrement count on drop
            // We cannot call async .decrement_request_count() here.
            // This requires spawning a task, which is often discouraged in Drop.
            // A potential workaround is to send the ID to a cleanup task, 
            // or accept that counts might be slightly off if drop happens unexpectedly.
            // For now, we'll spawn a task, acknowledging the potential issues.
            let router = self.router.clone();
            let instance_id = self.instance_id.clone();
            tokio::spawn(async move {
                if let Err(e) = router.decrement_request_count(&instance_id).await {
                     error!(instance_id = %instance_id, error = ?e, "Failed to decrement request count in RequestTracker drop");
                }
            });
        }
    }
}

// --- Health Checking Task ---
struct HealthCheck;

impl HealthCheck {
    fn spawn(
        state: Arc<RouterInternalState>,
        mut stop_rx: oneshot::Receiver<()>,
        health_check_client: reqwest::Client, // Accept dedicated client
        health_check_interval_duration: Duration,
        _health_check_timeout: Duration, // Prefix with _ to silence warning
        health_check_path: String,
    ) {
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(health_check_interval_duration);
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);

            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        debug!("Running health check cycle...");
                        // Pass the dedicated client down
                        Self::perform_checks(&state, &health_check_client, &health_check_path, _health_check_timeout).await;
                    }
                    _ = &mut stop_rx => {
                        info!("Stopping health check task.");
                        break;
                    }
                }
            }
        });
    }

    async fn perform_checks(
        state: &Arc<RouterInternalState>,
        health_check_client: &reqwest::Client, // Use dedicated client passed as argument
        health_check_path: &str,
        _health_check_timeout: Duration, // Prefix with _ to silence warning
    ) {
        let instances = state.instances.read().await.clone();
        // No longer need to clone client from state
        // let client = state.http_client.clone(); 
        const CONCURRENCY_LIMIT: usize = 10;

        stream::iter(instances.into_iter())
            .map(|instance| {
                // Clone the DEDICATED client for each task
                let client = health_check_client.clone(); 
                let path = health_check_path.to_string();
                async move {
                    // Skip check if instance is explicitly timed out
                    if instance.is_in_timeout.load(Ordering::SeqCst) {
                        let expiry = *instance.timeout_until.read().await;
                        if let Some(exp) = expiry {
                            if Instant::now() < exp {
                                debug!(instance_id=%instance.id, "Skipping health check, instance in timeout.");
                                return; // Still timed out
                            } else {
                                instance.is_in_timeout.store(false, Ordering::SeqCst);
                                *instance.timeout_until.write().await = None;
                                info!(instance_id=%instance.id, "Instance timeout expired during health check cycle.");
                            }
                        } else {
                            instance.is_in_timeout.store(false, Ordering::SeqCst);
                        }
                    }

                    let url = format!("{}/{}", instance.base_url.trim_end_matches('/'), path.trim_start_matches('/'));
                    // Use the passed-in client, timeout is already set on the client itself
                    let result = client.get(&url).send().await; 
                    let new_status = match result {
                        Ok(response) => {
                            if response.status().is_success() {
                                InstanceStatus::Healthy
                            } else {
                                warn!(instance_id=%instance.id, status=%response.status(), url=%url, "Health check failed with status");
                                InstanceStatus::Unhealthy
                            }
                        }
                        Err(e) => {
                            warn!(instance_id=%instance.id, error=?e, url=%url, "Health check request failed");
                            InstanceStatus::Unhealthy
                        }
                    };

                    let mut status_guard = instance.status.write().await;
                    if *status_guard != new_status {
                        info!(instance_id=%instance.id, old_status=?*status_guard, new_status=?new_status, "Instance health status changed");
                        *status_guard = new_status;
                    }
                }
            })
            .buffer_unordered(CONCURRENCY_LIMIT)
            .collect::<()>()
            .await;

        debug!("Finished health check cycle.");
    }
}

// --- Router Drop Implementation ---
impl Drop for Router {
    fn drop(&mut self) {
        // Attempt to send stop signal to health check task using try_write to avoid blocking
        if let Ok(mut guard) = self.stop_health_check_tx.try_write() {
            if let Some(tx) = guard.take() {
                let _ = tx.send(()); // Ignore error if receiver is already dropped
                info!("Sent stop signal to health check task via try_write.");
            }
        } else {
            // Optional: Log if try_write failed, indicating potential contention 
            // or issue during shutdown, but don't panic.
            // warn!("Could not acquire write lock on stop_health_check_tx during drop.");
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*; // Router, builder and error types come from the parent module (lib.rs)
    use crate::types::{InstanceStatus, ModelCapability};
    use std::time::Duration;

    /// Builds a router with the round-robin strategy and no other builder options set.
    fn create_test_router() -> Router {
        let builder = Router::builder().strategy(RoutingStrategy::RoundRobin);
        builder.build()
    }

    /// A freshly created builder carries the documented default settings.
    #[tokio::test]
    async fn test_builder_defaults() {
        let defaults = RouterBuilder::new();

        // Load-based routing is the default strategy.
        assert_eq!(defaults.strategy, RoutingStrategy::LoadBased);
        assert_eq!(defaults.health_check_interval, DEFAULT_HEALTH_CHECK_INTERVAL);
        assert_eq!(defaults.health_check_timeout, DEFAULT_HEALTH_CHECK_TIMEOUT);
        assert_eq!(defaults.health_check_path, DEFAULT_HEALTH_CHECK_PATH);
        assert_eq!(defaults.instance_timeout_duration, DEFAULT_INSTANCE_TIMEOUT);
        assert!(defaults.initial_instances.is_empty());
        assert!(defaults.reqwest_client.is_none());
    }

    /// Every builder setter stores the value it was given.
    #[tokio::test]
    async fn test_builder_custom_config() {
        let interval = Duration::from_secs(10);
        let check_timeout = Duration::from_secs(5);
        let instance_timeout = Duration::from_secs(60);
        let http_client = reqwest::Client::new();

        let configured = RouterBuilder::new()
            .strategy(RoutingStrategy::RoundRobin)
            .health_check_interval(interval)
            .health_check_timeout(check_timeout)
            .health_check_path("/status")
            .instance_timeout_duration(instance_timeout)
            .http_client(http_client.clone())
            .instance("test1", "http://localhost:8080");

        assert_eq!(configured.strategy, RoutingStrategy::RoundRobin);
        assert_eq!(configured.health_check_interval, interval);
        assert_eq!(configured.health_check_timeout, check_timeout);
        assert_eq!(configured.health_check_path, "/status");
        assert_eq!(configured.instance_timeout_duration, instance_timeout);
        assert_eq!(configured.initial_instances.len(), 1);
        assert!(configured.reqwest_client.is_some());
    }

    /// Covers adding, duplicate-adding, removing, and removing an unknown instance.
    #[tokio::test]
    async fn test_add_remove_instance() {
        let router = create_test_router();
        assert!(router.get_instances().await.is_empty());

        // A new id is accepted and the instance starts out with Unknown status.
        assert!(router.add_instance("inst1", "http://127.0.0.1:1111").await.is_ok());
        let registered = router.get_instances().await;
        assert_eq!(registered.len(), 1);
        let first = &registered[0];
        assert_eq!(first.id, "inst1");
        assert_eq!(*first.status.read().await, InstanceStatus::Unknown);

        // Re-using an id is rejected and leaves the instance list untouched.
        let duplicate = router.add_instance("inst1", "http://127.0.0.1:2222").await;
        assert!(matches!(duplicate, Err(RouterError::InstanceExists(_))));
        assert_eq!(router.get_instances().await.len(), 1);

        // Removing the known id empties the router again.
        assert!(router.remove_instance("inst1").await.is_ok());
        assert!(router.get_instances().await.is_empty());

        // Removing an id that was never added reports InstanceNotFound.
        let missing = router.remove_instance("inst2").await;
        assert!(matches!(missing, Err(RouterError::InstanceNotFound(_))));
    }

    /// Models supplied at registration time show up in the instance's model map.
    #[tokio::test]
    async fn test_add_instance_with_models() {
        let router = create_test_router();

        let chat_model = ModelInstanceConfig {
            model_name: String::from("gpt-4"),
            capabilities: vec![ModelCapability::Chat],
        };
        let embedding_model = ModelInstanceConfig {
            model_name: String::from("text-embed"),
            capabilities: vec![ModelCapability::Embedding],
        };

        let outcome = router
            .add_instance_with_models(
                "inst_models",
                "http://127.0.0.1:3333",
                vec![chat_model, embedding_model],
            )
            .await;
        assert!(outcome.is_ok());

        let registered = router.get_instances().await;
        assert_eq!(registered.len(), 1);
        let only = &registered[0];
        assert_eq!(only.id, "inst_models");

        let model_map = only.supported_models.read().await;
        assert_eq!(model_map.len(), 2);
        for expected_name in ["gpt-4", "text-embed"] {
            assert!(model_map.contains_key(expected_name));
        }
        let chat_caps = model_map.get("gpt-4").unwrap();
        assert_eq!(chat_caps, &vec![ModelCapability::Chat]);
    }

    /// A model can be attached to an existing instance, but not to an unknown one.
    #[tokio::test]
    async fn test_add_model_to_instance() {
        let router = create_test_router();
        router.add_instance("inst_add_model", "http://127.0.0.1:4444").await.unwrap();

        let attached = router
            .add_model_to_instance(
                "inst_add_model",
                String::from("new-model"),
                vec![ModelCapability::Completion],
            )
            .await;
        assert!(attached.is_ok());

        let registered = router.get_instances().await;
        let target = registered
            .iter()
            .find(|candidate| candidate.id == "inst_add_model")
            .unwrap();
        let model_map = target.supported_models.read().await;
        assert_eq!(model_map.len(), 1);
        assert!(model_map.contains_key("new-model"));
        let new_model_caps = model_map.get("new-model").unwrap();
        assert_eq!(new_model_caps, &vec![ModelCapability::Completion]);

        // Targeting an id that was never registered must fail with InstanceNotFound.
        let orphan = router
            .add_model_to_instance("nonexistent", String::from("test"), Vec::new())
            .await;
        assert!(matches!(orphan, Err(RouterError::InstanceNotFound(_))));
    }

    // Unit tests still to be written:
    // - instance_supports_model
    // - select_instance_for_model (requires mock instances/health checks)
    // - select_next_instance (requires mock instances/health checks)
    // - timeout_instance
    // - request counting (RequestTracker)
    // - HealthCheck logic (might need more involved mocking)
}