Skip to main content

aster/providers/
lead_worker.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use std::ops::Deref;
4use std::sync::Arc;
5use tokio::sync::Mutex;
6
7use super::base::{LeadWorkerProviderTrait, Provider, ProviderMetadata, ProviderUsage};
8use super::errors::ProviderError;
9use crate::conversation::message::{Message, MessageContent};
10use crate::model::ModelConfig;
11use rmcp::model::Tool;
12use rmcp::model::{Content, RawContent};
13
14/// A provider that switches between a lead model and a worker model based on turn count
15/// and can fallback to lead model on consecutive failures
16pub struct LeadWorkerProvider {
17    lead_provider: Arc<dyn Provider>,
18    worker_provider: Arc<dyn Provider>,
19    lead_turns: usize,
20    turn_count: Arc<Mutex<usize>>,
21    failure_count: Arc<Mutex<usize>>,
22    max_failures_before_fallback: usize,
23    fallback_turns: usize,
24    in_fallback_mode: Arc<Mutex<bool>>,
25    fallback_remaining: Arc<Mutex<usize>>,
26}
27
28impl LeadWorkerProvider {
29    /// Create a new LeadWorkerProvider
30    ///
31    /// # Arguments
32    /// * `lead_provider` - The provider to use for the initial turns
33    /// * `worker_provider` - The provider to use after lead_turns
34    /// * `lead_turns` - Number of turns to use the lead provider (default: 3)
35    pub fn new(
36        lead_provider: Arc<dyn Provider>,
37        worker_provider: Arc<dyn Provider>,
38        lead_turns: Option<usize>,
39    ) -> Self {
40        Self {
41            lead_provider,
42            worker_provider,
43            lead_turns: lead_turns.unwrap_or(3),
44            turn_count: Arc::new(Mutex::new(0)),
45            failure_count: Arc::new(Mutex::new(0)),
46            max_failures_before_fallback: 2, // Fallback after 2 consecutive failures
47            fallback_turns: 2,               // Use lead model for 2 turns when in fallback mode
48            in_fallback_mode: Arc::new(Mutex::new(false)),
49            fallback_remaining: Arc::new(Mutex::new(0)),
50        }
51    }
52
53    /// Create a new LeadWorkerProvider with custom settings
54    ///
55    /// # Arguments
56    /// * `lead_provider` - The provider to use for the initial turns
57    /// * `worker_provider` - The provider to use after lead_turns
58    /// * `lead_turns` - Number of turns to use the lead provider
59    /// * `failure_threshold` - Number of consecutive failures before fallback
60    /// * `fallback_turns` - Number of turns to use lead model in fallback mode
61    pub fn new_with_settings(
62        lead_provider: Arc<dyn Provider>,
63        worker_provider: Arc<dyn Provider>,
64        lead_turns: usize,
65        failure_threshold: usize,
66        fallback_turns: usize,
67    ) -> Self {
68        Self {
69            lead_provider,
70            worker_provider,
71            lead_turns,
72            turn_count: Arc::new(Mutex::new(0)),
73            failure_count: Arc::new(Mutex::new(0)),
74            max_failures_before_fallback: failure_threshold,
75            fallback_turns,
76            in_fallback_mode: Arc::new(Mutex::new(false)),
77            fallback_remaining: Arc::new(Mutex::new(0)),
78        }
79    }
80
81    /// Reset the turn counter and failure tracking (useful for new conversations)
82    pub async fn reset_turn_count(&self) {
83        let mut count = self.turn_count.lock().await;
84        *count = 0;
85        let mut failures = self.failure_count.lock().await;
86        *failures = 0;
87        let mut fallback = self.in_fallback_mode.lock().await;
88        *fallback = false;
89        let mut remaining = self.fallback_remaining.lock().await;
90        *remaining = 0;
91    }
92
93    /// Get the current turn count
94    pub async fn get_turn_count(&self) -> usize {
95        *self.turn_count.lock().await
96    }
97
98    /// Get the current failure count
99    pub async fn get_failure_count(&self) -> usize {
100        *self.failure_count.lock().await
101    }
102
103    /// Check if currently in fallback mode
104    pub async fn is_in_fallback_mode(&self) -> bool {
105        *self.in_fallback_mode.lock().await
106    }
107
108    /// Get the currently active provider based on turn count and fallback state
109    async fn get_active_provider(&self) -> Arc<dyn Provider> {
110        let count = *self.turn_count.lock().await;
111        let in_fallback = *self.in_fallback_mode.lock().await;
112
113        // Use lead provider if we're in initial turns OR in fallback mode
114        if count < self.lead_turns || in_fallback {
115            Arc::clone(&self.lead_provider)
116        } else {
117            Arc::clone(&self.worker_provider)
118        }
119    }
120
121    /// Handle the result of a completion attempt and update failure tracking
122    async fn handle_completion_result(
123        &self,
124        result: &Result<(Message, ProviderUsage), ProviderError>,
125    ) {
126        match result {
127            Ok((message, _usage)) => {
128                // Check for task-level failures in the response
129                let has_task_failure = self.detect_task_failures(message).await;
130
131                if has_task_failure {
132                    // Task failure detected - increment failure count
133                    let mut failures = self.failure_count.lock().await;
134                    *failures += 1;
135
136                    let failure_count = *failures;
137                    let turn_count = *self.turn_count.lock().await;
138
139                    tracing::warn!(
140                        "Task failure detected in response (failure count: {})",
141                        failure_count
142                    );
143
144                    // Check if we should trigger fallback
145                    if turn_count >= self.lead_turns
146                        && !*self.in_fallback_mode.lock().await
147                        && failure_count >= self.max_failures_before_fallback
148                    {
149                        let mut in_fallback = self.in_fallback_mode.lock().await;
150                        let mut fallback_remaining = self.fallback_remaining.lock().await;
151
152                        *in_fallback = true;
153                        *fallback_remaining = self.fallback_turns;
154                        *failures = 0; // Reset failure count when entering fallback
155
156                        tracing::warn!(
157                            "🔄 SWITCHING TO LEAD MODEL: Entering fallback mode after {} consecutive task failures - using lead model for {} turns",
158                            self.max_failures_before_fallback,
159                            self.fallback_turns
160                        );
161                    }
162                } else {
163                    // Success - reset failure count and handle fallback mode
164                    let mut failures = self.failure_count.lock().await;
165                    *failures = 0;
166
167                    let mut in_fallback = self.in_fallback_mode.lock().await;
168                    let mut fallback_remaining = self.fallback_remaining.lock().await;
169
170                    if *in_fallback {
171                        *fallback_remaining -= 1;
172                        if *fallback_remaining == 0 {
173                            *in_fallback = false;
174                            tracing::info!("✅ SWITCHING BACK TO WORKER MODEL: Exiting fallback mode - worker model resumed");
175                        }
176                    }
177                }
178
179                // Increment turn count on any completion (success or task failure)
180                let mut count = self.turn_count.lock().await;
181                *count += 1;
182            }
183            Err(_) => {
184                // Technical failure - just log and let it bubble up
185                // For technical failures (API/LLM issues), we don't want to second-guess
186                // the model choice - just let the default model handle it
187                tracing::warn!(
188                    "Technical failure detected - API/LLM issue, will use default model"
189                );
190
191                // Don't increment turn count or failure tracking for technical failures
192                // as these are temporary infrastructure issues, not model capability issues
193            }
194        }
195    }
196
197    /// Detect task-level failures in the model's response
198    async fn detect_task_failures(&self, message: &Message) -> bool {
199        let mut failure_indicators = 0;
200
201        for content in &message.content {
202            match content {
203                MessageContent::ToolRequest(tool_request) => {
204                    // Check if tool request itself failed (malformed, etc.)
205                    if tool_request.tool_call.is_err() {
206                        failure_indicators += 1;
207                        tracing::debug!(
208                            "Failed tool request detected: {:?}",
209                            tool_request.tool_call
210                        );
211                    }
212                }
213                MessageContent::ToolResponse(tool_response) => {
214                    // Check if tool execution failed
215                    if let Err(tool_error) = &tool_response.tool_result {
216                        failure_indicators += 1;
217                        tracing::debug!("Tool execution failure detected: {:?}", tool_error);
218                    } else if let Ok(result) = &tool_response.tool_result {
219                        // Check tool output for error indicators
220                        if self.contains_error_indicators(&result.content) {
221                            failure_indicators += 1;
222                            tracing::debug!("Tool output contains error indicators");
223                        }
224                    }
225                }
226                MessageContent::Text(text_content) => {
227                    // Check for user correction patterns or error acknowledgments
228                    if self.contains_user_correction_patterns(&text_content.text) {
229                        failure_indicators += 1;
230                        tracing::debug!("User correction pattern detected in text");
231                    }
232                }
233                _ => {}
234            }
235        }
236
237        // Consider it a failure if we have multiple failure indicators
238        failure_indicators >= 1
239    }
240
241    /// Check if tool output contains error indicators
242    fn contains_error_indicators(&self, contents: &[Content]) -> bool {
243        for content in contents {
244            if let RawContent::Text(text_content) = content.deref() {
245                let text_lower = text_content.text.to_lowercase();
246
247                // Common error patterns in tool outputs
248                if text_lower.contains("error:")
249                    || text_lower.contains("failed:")
250                    || text_lower.contains("exception:")
251                    || text_lower.contains("traceback")
252                    || text_lower.contains("syntax error")
253                    || text_lower.contains("permission denied")
254                    || text_lower.contains("file not found")
255                    || text_lower.contains("command not found")
256                    || text_lower.contains("compilation failed")
257                    || text_lower.contains("test failed")
258                    || text_lower.contains("assertion failed")
259                {
260                    return true;
261                }
262            }
263        }
264        false
265    }
266
267    /// Check for user correction patterns in text
268    fn contains_user_correction_patterns(&self, text: &str) -> bool {
269        let text_lower = text.to_lowercase();
270
271        // Patterns indicating user is correcting or expressing dissatisfaction
272        text_lower.contains("that's wrong")
273            || text_lower.contains("that's not right")
274            || text_lower.contains("that doesn't work")
275            || text_lower.contains("try again")
276            || text_lower.contains("let me correct")
277            || text_lower.contains("actually, ")
278            || text_lower.contains("no, that's")
279            || text_lower.contains("that's incorrect")
280            || text_lower.contains("fix this")
281            || text_lower.contains("this is broken")
282            || text_lower.contains("this doesn't")
283            || text_lower.starts_with("no,")
284            || text_lower.starts_with("wrong")
285            || text_lower.starts_with("incorrect")
286    }
287}
288
289impl LeadWorkerProviderTrait for LeadWorkerProvider {
290    /// Get information about the lead and worker models for logging
291    fn get_model_info(&self) -> (String, String) {
292        let lead_model = self.lead_provider.get_model_config().model_name;
293        let worker_model = self.worker_provider.get_model_config().model_name;
294        (lead_model, worker_model)
295    }
296
297    /// Get the currently active model name
298    fn get_active_model(&self) -> String {
299        // Read from the global store which was set during complete()
300        use super::base::get_current_model;
301        get_current_model().unwrap_or_else(|| {
302            // Fallback to lead model if no current model is set
303            self.lead_provider.get_model_config().model_name
304        })
305    }
306
307    /// Get (lead_turns, failure_threshold, fallback_turns)
308    fn get_settings(&self) -> (usize, usize, usize) {
309        (
310            self.lead_turns,
311            self.max_failures_before_fallback,
312            self.fallback_turns,
313        )
314    }
315}
316
317#[async_trait]
318impl Provider for LeadWorkerProvider {
319    fn metadata() -> ProviderMetadata {
320        // This is a wrapper provider, so we return minimal metadata
321        ProviderMetadata::new(
322            "lead_worker",
323            "Lead/Worker Provider",
324            "A provider that switches between lead and worker models based on turn count",
325            "",     // No default model as this is determined by the wrapped providers
326            vec![], // No known models as this depends on wrapped providers
327            "",     // No doc link
328            vec![], // No config keys as configuration is done through wrapped providers
329        )
330    }
331
332    fn get_name(&self) -> &str {
333        // Return the lead provider's name as the default
334        self.lead_provider.get_name()
335    }
336
337    fn get_model_config(&self) -> ModelConfig {
338        // Return the lead provider's model config as the default
339        // In practice, this might need to be more sophisticated
340        self.lead_provider.get_model_config()
341    }
342
343    async fn complete_with_model(
344        &self,
345        _model_config: &ModelConfig,
346        system: &str,
347        messages: &[Message],
348        tools: &[Tool],
349    ) -> Result<(Message, ProviderUsage), ProviderError> {
350        // Get the active provider
351        let provider = self.get_active_provider().await;
352
353        // Log which provider is being used
354        let turn_count = *self.turn_count.lock().await;
355        let in_fallback = *self.in_fallback_mode.lock().await;
356        let fallback_remaining = *self.fallback_remaining.lock().await;
357
358        let provider_type = if turn_count < self.lead_turns {
359            "lead (initial)"
360        } else if in_fallback {
361            "lead (fallback)"
362        } else {
363            "worker"
364        };
365
366        // Get the active model name and update the global store
367        let active_model_name = if turn_count < self.lead_turns || in_fallback {
368            self.lead_provider.get_model_config().model_name.clone()
369        } else {
370            self.worker_provider.get_model_config().model_name.clone()
371        };
372
373        // Update the global current model store
374        super::base::set_current_model(&active_model_name);
375
376        if in_fallback {
377            tracing::info!(
378                "🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining) - Model: {}",
379                provider_type,
380                turn_count + 1,
381                fallback_remaining,
382                active_model_name
383            );
384        } else {
385            tracing::info!(
386                "Using {} provider for turn {} (lead_turns: {}) - Model: {}",
387                provider_type,
388                turn_count + 1,
389                self.lead_turns,
390                active_model_name
391            );
392        }
393
394        // Make the completion request
395        let result = provider.complete(system, messages, tools).await;
396
397        // For technical failures, try with default model (lead provider) instead
398        let final_result = match &result {
399            Err(_) => {
400                tracing::warn!("Technical failure with {} provider, retrying with default model (lead provider)", provider_type);
401
402                // Try with lead provider as the default/fallback for technical failures
403                let default_result = self.lead_provider.complete(system, messages, tools).await;
404
405                match &default_result {
406                    Ok(_) => {
407                        tracing::info!(
408                            "✅ Default model (lead provider) succeeded after technical failure"
409                        );
410                        default_result
411                    }
412                    Err(_) => {
413                        tracing::error!("❌ Default model (lead provider) also failed - returning original error");
414                        result // Return the original error
415                    }
416                }
417            }
418            Ok(_) => result, // Success with original provider
419        };
420
421        // Handle the result and update tracking (only for successful completions)
422        self.handle_completion_result(&final_result).await;
423
424        final_result
425    }
426
427    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
428        // Combine models from both providers
429        let lead_models = self.lead_provider.fetch_supported_models().await?;
430        let worker_models = self.worker_provider.fetch_supported_models().await?;
431
432        match (lead_models, worker_models) {
433            (Some(lead), Some(worker)) => {
434                let mut all_models = lead;
435                all_models.extend(worker);
436                all_models.sort();
437                all_models.dedup();
438                Ok(Some(all_models))
439            }
440            (Some(models), None) | (None, Some(models)) => Ok(Some(models)),
441            (None, None) => Ok(None),
442        }
443    }
444
445    fn supports_embeddings(&self) -> bool {
446        // Support embeddings if either provider supports them
447        self.lead_provider.supports_embeddings() || self.worker_provider.supports_embeddings()
448    }
449
450    async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
451        // Use the lead provider for embeddings if it supports them, otherwise use worker
452        if self.lead_provider.supports_embeddings() {
453            self.lead_provider.create_embeddings(texts).await
454        } else if self.worker_provider.supports_embeddings() {
455            self.worker_provider.create_embeddings(texts).await
456        } else {
457            Err(ProviderError::ExecutionError(
458                "Neither lead nor worker provider supports embeddings".to_string(),
459            ))
460        }
461    }
462
463    /// Check if this provider is a LeadWorkerProvider
464    fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
465        Some(self)
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conversation::message::{Message, MessageContent};
    use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
    use chrono::Utc;
    use rmcp::model::{AnnotateAble, RawTextContent, Role};

    /// What a [`MockProvider`] does when asked for a completion.
    #[derive(Clone, Copy)]
    enum MockBehavior {
        /// Return a clean text response.
        Succeed,
        /// Return `Err(ProviderError::ExecutionError)` (a technical failure).
        TechnicalFailure,
        /// Return `Ok` with text that matches the user-correction patterns
        /// (a task failure).
        TaskFailure,
    }

    /// Minimal provider whose usage record carries its own name, so tests can
    /// tell which wrapped provider produced a response.
    #[derive(Clone)]
    struct MockProvider {
        name: String,
        model_config: ModelConfig,
        behavior: MockBehavior,
    }

    impl MockProvider {
        /// Build an assistant text message plus a usage record tagged with this
        /// mock's name.
        fn text_response(&self, text: String) -> (Message, ProviderUsage) {
            (
                Message::new(
                    Role::Assistant,
                    Utc::now().timestamp(),
                    vec![MessageContent::Text(
                        RawTextContent { text, meta: None }.no_annotation(),
                    )],
                ),
                ProviderUsage::new(self.name.clone(), Usage::default()),
            )
        }
    }

    /// Create a mock named `name` with model `"<name>-model"`.
    fn mock_provider(name: &str, behavior: MockBehavior) -> Arc<MockProvider> {
        Arc::new(MockProvider {
            name: name.to_string(),
            model_config: ModelConfig::new_or_fail(&format!("{}-model", name)),
            behavior,
        })
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::empty()
        }

        fn get_name(&self) -> &str {
            &self.name
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        async fn complete_with_model(
            &self,
            _model_config: &ModelConfig,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            match self.behavior {
                MockBehavior::Succeed => {
                    Ok(self.text_response(format!("Response from {}", self.name)))
                }
                MockBehavior::TechnicalFailure => Err(ProviderError::ExecutionError(
                    "Simulated failure".to_string(),
                )),
                MockBehavior::TaskFailure => {
                    Ok(self.text_response("That doesn't work, let me try again".to_string()))
                }
            }
        }
    }

    #[tokio::test]
    async fn test_lead_worker_switching() {
        let provider = LeadWorkerProvider::new(
            mock_provider("lead", MockBehavior::Succeed),
            mock_provider("worker", MockBehavior::Succeed),
            Some(3),
        );

        // First three turns should use lead provider
        for i in 0..3 {
            let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
            assert_eq!(usage.model, "lead");
            assert_eq!(provider.get_turn_count().await, i + 1);
            assert!(!provider.is_in_fallback_mode().await);
        }

        // Subsequent turns should use worker provider
        for i in 3..6 {
            let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
            assert_eq!(usage.model, "worker");
            assert_eq!(provider.get_turn_count().await, i + 1);
            assert!(!provider.is_in_fallback_mode().await);
        }

        // Reset and verify it goes back to lead
        provider.reset_turn_count().await;
        assert_eq!(provider.get_turn_count().await, 0);
        assert_eq!(provider.get_failure_count().await, 0);
        assert!(!provider.is_in_fallback_mode().await);

        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "lead");
    }

    #[tokio::test]
    async fn test_technical_failure_retry() {
        let provider = LeadWorkerProvider::new(
            mock_provider("lead", MockBehavior::Succeed),
            mock_provider("worker", MockBehavior::TechnicalFailure),
            Some(2),
        );

        // First two turns use lead (should succeed)
        for _ in 0..2 {
            let result = provider.complete("system", &[], &[]).await;
            assert!(result.is_ok());
            assert_eq!(result.unwrap().1.model, "lead");
            assert!(!provider.is_in_fallback_mode().await);
        }

        // The next two turns try the worker first (which fails) and are then
        // retried on the lead provider. Technical failures must not touch the
        // failure count or trigger fallback mode.
        for _ in 0..2 {
            let result = provider.complete("system", &[], &[]).await;
            assert!(result.is_ok());
            assert_eq!(result.unwrap().1.model, "lead");
            assert_eq!(provider.get_failure_count().await, 0);
            assert!(!provider.is_in_fallback_mode().await);
        }
    }

    #[tokio::test]
    async fn test_fallback_on_task_failures() {
        // Exercises the fallback countdown by placing the provider directly in
        // fallback mode; entering fallback through real task failures is covered
        // by `test_task_failures_trigger_fallback`.
        let provider = LeadWorkerProvider::new(
            mock_provider("lead", MockBehavior::Succeed),
            mock_provider("worker", MockBehavior::Succeed),
            Some(2),
        );

        // Simulate being in fallback mode
        {
            let mut in_fallback = provider.in_fallback_mode.lock().await;
            *in_fallback = true;
            let mut fallback_remaining = provider.fallback_remaining.lock().await;
            *fallback_remaining = 2;
            let mut turn_count = provider.turn_count.lock().await;
            *turn_count = 4; // Past initial lead turns
        }

        // Should use lead provider in fallback mode
        let result = provider.complete("system", &[], &[]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().1.model, "lead");
        assert!(provider.is_in_fallback_mode().await);

        // One more fallback turn
        let result = provider.complete("system", &[], &[]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().1.model, "lead");
        assert!(!provider.is_in_fallback_mode().await); // Should exit fallback mode
    }

    #[tokio::test]
    async fn test_task_failures_trigger_fallback() {
        // Worker responses succeed technically but contain correction phrases,
        // so each one counts as a task failure.
        let provider = LeadWorkerProvider::new(
            mock_provider("lead", MockBehavior::Succeed),
            mock_provider("worker", MockBehavior::TaskFailure),
            Some(1),
        );

        // Turn 1: initial lead turn
        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "lead");
        assert_eq!(provider.get_failure_count().await, 0);

        // Turn 2: first worker task failure - below the default threshold of 2
        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "worker");
        assert_eq!(provider.get_failure_count().await, 1);
        assert!(!provider.is_in_fallback_mode().await);

        // Turn 3: second consecutive failure - enters fallback and clears the count
        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "worker");
        assert_eq!(provider.get_failure_count().await, 0);
        assert!(provider.is_in_fallback_mode().await);

        // Turn 4: first of the two default fallback turns on the lead provider
        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "lead");
        assert!(provider.is_in_fallback_mode().await);

        // Turn 5: second fallback turn - fallback mode ends afterwards
        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "lead");
        assert!(!provider.is_in_fallback_mode().await);

        // Turn 6: back on the worker provider
        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "worker");
        assert_eq!(provider.get_failure_count().await, 1);
        assert_eq!(provider.get_turn_count().await, 6);
    }
}