1use anyhow::Result;
2use async_trait::async_trait;
3use std::ops::Deref;
4use std::sync::Arc;
5use tokio::sync::Mutex;
6
7use super::base::{LeadWorkerProviderTrait, Provider, ProviderMetadata, ProviderUsage};
8use super::errors::ProviderError;
9use crate::conversation::message::{Message, MessageContent};
10use crate::model::ModelConfig;
11use rmcp::model::Tool;
12use rmcp::model::{Content, RawContent};
13
14pub struct LeadWorkerProvider {
17 lead_provider: Arc<dyn Provider>,
18 worker_provider: Arc<dyn Provider>,
19 lead_turns: usize,
20 turn_count: Arc<Mutex<usize>>,
21 failure_count: Arc<Mutex<usize>>,
22 max_failures_before_fallback: usize,
23 fallback_turns: usize,
24 in_fallback_mode: Arc<Mutex<bool>>,
25 fallback_remaining: Arc<Mutex<usize>>,
26}
27
28impl LeadWorkerProvider {
29 pub fn new(
36 lead_provider: Arc<dyn Provider>,
37 worker_provider: Arc<dyn Provider>,
38 lead_turns: Option<usize>,
39 ) -> Self {
40 Self {
41 lead_provider,
42 worker_provider,
43 lead_turns: lead_turns.unwrap_or(3),
44 turn_count: Arc::new(Mutex::new(0)),
45 failure_count: Arc::new(Mutex::new(0)),
46 max_failures_before_fallback: 2, fallback_turns: 2, in_fallback_mode: Arc::new(Mutex::new(false)),
49 fallback_remaining: Arc::new(Mutex::new(0)),
50 }
51 }
52
53 pub fn new_with_settings(
62 lead_provider: Arc<dyn Provider>,
63 worker_provider: Arc<dyn Provider>,
64 lead_turns: usize,
65 failure_threshold: usize,
66 fallback_turns: usize,
67 ) -> Self {
68 Self {
69 lead_provider,
70 worker_provider,
71 lead_turns,
72 turn_count: Arc::new(Mutex::new(0)),
73 failure_count: Arc::new(Mutex::new(0)),
74 max_failures_before_fallback: failure_threshold,
75 fallback_turns,
76 in_fallback_mode: Arc::new(Mutex::new(false)),
77 fallback_remaining: Arc::new(Mutex::new(0)),
78 }
79 }
80
81 pub async fn reset_turn_count(&self) {
83 let mut count = self.turn_count.lock().await;
84 *count = 0;
85 let mut failures = self.failure_count.lock().await;
86 *failures = 0;
87 let mut fallback = self.in_fallback_mode.lock().await;
88 *fallback = false;
89 let mut remaining = self.fallback_remaining.lock().await;
90 *remaining = 0;
91 }
92
93 pub async fn get_turn_count(&self) -> usize {
95 *self.turn_count.lock().await
96 }
97
98 pub async fn get_failure_count(&self) -> usize {
100 *self.failure_count.lock().await
101 }
102
103 pub async fn is_in_fallback_mode(&self) -> bool {
105 *self.in_fallback_mode.lock().await
106 }
107
108 async fn get_active_provider(&self) -> Arc<dyn Provider> {
110 let count = *self.turn_count.lock().await;
111 let in_fallback = *self.in_fallback_mode.lock().await;
112
113 if count < self.lead_turns || in_fallback {
115 Arc::clone(&self.lead_provider)
116 } else {
117 Arc::clone(&self.worker_provider)
118 }
119 }
120
121 async fn handle_completion_result(
123 &self,
124 result: &Result<(Message, ProviderUsage), ProviderError>,
125 ) {
126 match result {
127 Ok((message, _usage)) => {
128 let has_task_failure = self.detect_task_failures(message).await;
130
131 if has_task_failure {
132 let mut failures = self.failure_count.lock().await;
134 *failures += 1;
135
136 let failure_count = *failures;
137 let turn_count = *self.turn_count.lock().await;
138
139 tracing::warn!(
140 "Task failure detected in response (failure count: {})",
141 failure_count
142 );
143
144 if turn_count >= self.lead_turns
146 && !*self.in_fallback_mode.lock().await
147 && failure_count >= self.max_failures_before_fallback
148 {
149 let mut in_fallback = self.in_fallback_mode.lock().await;
150 let mut fallback_remaining = self.fallback_remaining.lock().await;
151
152 *in_fallback = true;
153 *fallback_remaining = self.fallback_turns;
154 *failures = 0; tracing::warn!(
157 "🔄 SWITCHING TO LEAD MODEL: Entering fallback mode after {} consecutive task failures - using lead model for {} turns",
158 self.max_failures_before_fallback,
159 self.fallback_turns
160 );
161 }
162 } else {
163 let mut failures = self.failure_count.lock().await;
165 *failures = 0;
166
167 let mut in_fallback = self.in_fallback_mode.lock().await;
168 let mut fallback_remaining = self.fallback_remaining.lock().await;
169
170 if *in_fallback {
171 *fallback_remaining -= 1;
172 if *fallback_remaining == 0 {
173 *in_fallback = false;
174 tracing::info!("✅ SWITCHING BACK TO WORKER MODEL: Exiting fallback mode - worker model resumed");
175 }
176 }
177 }
178
179 let mut count = self.turn_count.lock().await;
181 *count += 1;
182 }
183 Err(_) => {
184 tracing::warn!(
188 "Technical failure detected - API/LLM issue, will use default model"
189 );
190
191 }
194 }
195 }
196
197 async fn detect_task_failures(&self, message: &Message) -> bool {
199 let mut failure_indicators = 0;
200
201 for content in &message.content {
202 match content {
203 MessageContent::ToolRequest(tool_request) => {
204 if tool_request.tool_call.is_err() {
206 failure_indicators += 1;
207 tracing::debug!(
208 "Failed tool request detected: {:?}",
209 tool_request.tool_call
210 );
211 }
212 }
213 MessageContent::ToolResponse(tool_response) => {
214 if let Err(tool_error) = &tool_response.tool_result {
216 failure_indicators += 1;
217 tracing::debug!("Tool execution failure detected: {:?}", tool_error);
218 } else if let Ok(result) = &tool_response.tool_result {
219 if self.contains_error_indicators(&result.content) {
221 failure_indicators += 1;
222 tracing::debug!("Tool output contains error indicators");
223 }
224 }
225 }
226 MessageContent::Text(text_content) => {
227 if self.contains_user_correction_patterns(&text_content.text) {
229 failure_indicators += 1;
230 tracing::debug!("User correction pattern detected in text");
231 }
232 }
233 _ => {}
234 }
235 }
236
237 failure_indicators >= 1
239 }
240
241 fn contains_error_indicators(&self, contents: &[Content]) -> bool {
243 for content in contents {
244 if let RawContent::Text(text_content) = content.deref() {
245 let text_lower = text_content.text.to_lowercase();
246
247 if text_lower.contains("error:")
249 || text_lower.contains("failed:")
250 || text_lower.contains("exception:")
251 || text_lower.contains("traceback")
252 || text_lower.contains("syntax error")
253 || text_lower.contains("permission denied")
254 || text_lower.contains("file not found")
255 || text_lower.contains("command not found")
256 || text_lower.contains("compilation failed")
257 || text_lower.contains("test failed")
258 || text_lower.contains("assertion failed")
259 {
260 return true;
261 }
262 }
263 }
264 false
265 }
266
267 fn contains_user_correction_patterns(&self, text: &str) -> bool {
269 let text_lower = text.to_lowercase();
270
271 text_lower.contains("that's wrong")
273 || text_lower.contains("that's not right")
274 || text_lower.contains("that doesn't work")
275 || text_lower.contains("try again")
276 || text_lower.contains("let me correct")
277 || text_lower.contains("actually, ")
278 || text_lower.contains("no, that's")
279 || text_lower.contains("that's incorrect")
280 || text_lower.contains("fix this")
281 || text_lower.contains("this is broken")
282 || text_lower.contains("this doesn't")
283 || text_lower.starts_with("no,")
284 || text_lower.starts_with("wrong")
285 || text_lower.starts_with("incorrect")
286 }
287}
288
289impl LeadWorkerProviderTrait for LeadWorkerProvider {
290 fn get_model_info(&self) -> (String, String) {
292 let lead_model = self.lead_provider.get_model_config().model_name;
293 let worker_model = self.worker_provider.get_model_config().model_name;
294 (lead_model, worker_model)
295 }
296
297 fn get_active_model(&self) -> String {
299 use super::base::get_current_model;
301 get_current_model().unwrap_or_else(|| {
302 self.lead_provider.get_model_config().model_name
304 })
305 }
306
307 fn get_settings(&self) -> (usize, usize, usize) {
309 (
310 self.lead_turns,
311 self.max_failures_before_fallback,
312 self.fallback_turns,
313 )
314 }
315}
316
317#[async_trait]
318impl Provider for LeadWorkerProvider {
319 fn metadata() -> ProviderMetadata {
320 ProviderMetadata::new(
322 "lead_worker",
323 "Lead/Worker Provider",
324 "A provider that switches between lead and worker models based on turn count",
325 "", vec![], "", vec![], )
330 }
331
332 fn get_name(&self) -> &str {
333 self.lead_provider.get_name()
335 }
336
337 fn get_model_config(&self) -> ModelConfig {
338 self.lead_provider.get_model_config()
341 }
342
343 async fn complete_with_model(
344 &self,
345 _model_config: &ModelConfig,
346 system: &str,
347 messages: &[Message],
348 tools: &[Tool],
349 ) -> Result<(Message, ProviderUsage), ProviderError> {
350 let provider = self.get_active_provider().await;
352
353 let turn_count = *self.turn_count.lock().await;
355 let in_fallback = *self.in_fallback_mode.lock().await;
356 let fallback_remaining = *self.fallback_remaining.lock().await;
357
358 let provider_type = if turn_count < self.lead_turns {
359 "lead (initial)"
360 } else if in_fallback {
361 "lead (fallback)"
362 } else {
363 "worker"
364 };
365
366 let active_model_name = if turn_count < self.lead_turns || in_fallback {
368 self.lead_provider.get_model_config().model_name.clone()
369 } else {
370 self.worker_provider.get_model_config().model_name.clone()
371 };
372
373 super::base::set_current_model(&active_model_name);
375
376 if in_fallback {
377 tracing::info!(
378 "🔄 Using {} provider for turn {} (FALLBACK MODE: {} turns remaining) - Model: {}",
379 provider_type,
380 turn_count + 1,
381 fallback_remaining,
382 active_model_name
383 );
384 } else {
385 tracing::info!(
386 "Using {} provider for turn {} (lead_turns: {}) - Model: {}",
387 provider_type,
388 turn_count + 1,
389 self.lead_turns,
390 active_model_name
391 );
392 }
393
394 let result = provider.complete(system, messages, tools).await;
396
397 let final_result = match &result {
399 Err(_) => {
400 tracing::warn!("Technical failure with {} provider, retrying with default model (lead provider)", provider_type);
401
402 let default_result = self.lead_provider.complete(system, messages, tools).await;
404
405 match &default_result {
406 Ok(_) => {
407 tracing::info!(
408 "✅ Default model (lead provider) succeeded after technical failure"
409 );
410 default_result
411 }
412 Err(_) => {
413 tracing::error!("❌ Default model (lead provider) also failed - returning original error");
414 result }
416 }
417 }
418 Ok(_) => result, };
420
421 self.handle_completion_result(&final_result).await;
423
424 final_result
425 }
426
427 async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
428 let lead_models = self.lead_provider.fetch_supported_models().await?;
430 let worker_models = self.worker_provider.fetch_supported_models().await?;
431
432 match (lead_models, worker_models) {
433 (Some(lead), Some(worker)) => {
434 let mut all_models = lead;
435 all_models.extend(worker);
436 all_models.sort();
437 all_models.dedup();
438 Ok(Some(all_models))
439 }
440 (Some(models), None) | (None, Some(models)) => Ok(Some(models)),
441 (None, None) => Ok(None),
442 }
443 }
444
445 fn supports_embeddings(&self) -> bool {
446 self.lead_provider.supports_embeddings() || self.worker_provider.supports_embeddings()
448 }
449
450 async fn create_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
451 if self.lead_provider.supports_embeddings() {
453 self.lead_provider.create_embeddings(texts).await
454 } else if self.worker_provider.supports_embeddings() {
455 self.worker_provider.create_embeddings(texts).await
456 } else {
457 Err(ProviderError::ExecutionError(
458 "Neither lead nor worker provider supports embeddings".to_string(),
459 ))
460 }
461 }
462
463 fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
465 Some(self)
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conversation::message::{Message, MessageContent};
    use crate::providers::base::{ProviderMetadata, ProviderUsage, Usage};
    use chrono::Utc;
    use rmcp::model::{AnnotateAble, RawTextContent, Role};

    /// Builds the canned assistant reply ("Response from <name>") returned by
    /// both mock providers, with usage attributed to `name`.
    fn canned_response(name: &str) -> (Message, ProviderUsage) {
        let body = RawTextContent {
            text: format!("Response from {}", name),
            meta: None,
        }
        .no_annotation();
        let reply = Message::new(
            Role::Assistant,
            Utc::now().timestamp(),
            vec![MessageContent::Text(body)],
        );
        (reply, ProviderUsage::new(name.to_string(), Usage::default()))
    }

    /// Mock provider that always answers successfully.
    #[derive(Clone)]
    struct MockProvider {
        name: String,
        model_config: ModelConfig,
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::empty()
        }

        fn get_name(&self) -> &str {
            "mock-lead"
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        async fn complete_with_model(
            &self,
            _model_config: &ModelConfig,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            Ok(canned_response(&self.name))
        }
    }

    /// Mock provider that either always answers or always errors.
    #[derive(Clone)]
    struct MockFailureProvider {
        name: String,
        model_config: ModelConfig,
        should_fail: bool,
    }

    #[async_trait]
    impl Provider for MockFailureProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::empty()
        }

        fn get_name(&self) -> &str {
            "mock-lead"
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        async fn complete_with_model(
            &self,
            _model_config: &ModelConfig,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            if !self.should_fail {
                return Ok(canned_response(&self.name));
            }
            Err(ProviderError::ExecutionError(
                "Simulated failure".to_string(),
            ))
        }
    }

    #[tokio::test]
    async fn test_lead_worker_switching() {
        let lead = Arc::new(MockProvider {
            name: "lead".to_string(),
            model_config: ModelConfig::new_or_fail("lead-model"),
        });
        let worker = Arc::new(MockProvider {
            name: "worker".to_string(),
            model_config: ModelConfig::new_or_fail("worker-model"),
        });
        let provider = LeadWorkerProvider::new(lead, worker, Some(3));

        // Turns 1..=3 are served by the lead provider.
        for turn in 1..=3 {
            let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
            assert_eq!(usage.model, "lead");
            assert_eq!(provider.get_turn_count().await, turn);
            assert!(!provider.is_in_fallback_mode().await);
        }

        // Turns 4..=6 are served by the worker provider.
        for turn in 4..=6 {
            let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
            assert_eq!(usage.model, "worker");
            assert_eq!(provider.get_turn_count().await, turn);
            assert!(!provider.is_in_fallback_mode().await);
        }

        provider.reset_turn_count().await;
        assert_eq!(provider.get_turn_count().await, 0);
        assert_eq!(provider.get_failure_count().await, 0);
        assert!(!provider.is_in_fallback_mode().await);

        let (_message, usage) = provider.complete("system", &[], &[]).await.unwrap();
        assert_eq!(usage.model, "lead");
    }

    #[tokio::test]
    async fn test_technical_failure_retry() {
        let lead = Arc::new(MockFailureProvider {
            name: "lead".to_string(),
            model_config: ModelConfig::new_or_fail("lead-model"),
            should_fail: false,
        });
        let worker = Arc::new(MockFailureProvider {
            name: "worker".to_string(),
            model_config: ModelConfig::new_or_fail("worker-model"),
            should_fail: true,
        });
        let provider = LeadWorkerProvider::new(lead, worker, Some(2));

        // Initial lead turns.
        for _ in 0..2 {
            let outcome = provider.complete("system", &[], &[]).await;
            assert!(outcome.is_ok());
            assert_eq!(outcome.unwrap().1.model, "lead");
            assert!(!provider.is_in_fallback_mode().await);
        }

        // Worker turns: the worker errors, so the lead answers in its place and
        // the technical failure does not count as a task failure.
        for _ in 0..2 {
            let outcome = provider.complete("system", &[], &[]).await;
            assert!(outcome.is_ok());
            assert_eq!(outcome.unwrap().1.model, "lead");
            assert_eq!(provider.get_failure_count().await, 0);
            assert!(!provider.is_in_fallback_mode().await);
        }
    }

    #[tokio::test]
    async fn test_fallback_on_task_failures() {
        let lead = Arc::new(MockFailureProvider {
            name: "lead".to_string(),
            model_config: ModelConfig::new_or_fail("lead-model"),
            should_fail: false,
        });
        let worker = Arc::new(MockFailureProvider {
            name: "worker".to_string(),
            model_config: ModelConfig::new_or_fail("worker-model"),
            should_fail: false,
        });
        let provider = LeadWorkerProvider::new(lead, worker, Some(2));

        // Force the provider into fallback mode with two turns left, past the
        // initial lead turns.
        {
            let mut in_fallback = provider.in_fallback_mode.lock().await;
            *in_fallback = true;
            let mut fallback_remaining = provider.fallback_remaining.lock().await;
            *fallback_remaining = 2;
            let mut turn_count = provider.turn_count.lock().await;
            *turn_count = 4;
        }

        let first = provider.complete("system", &[], &[]).await;
        assert!(first.is_ok());
        assert_eq!(first.unwrap().1.model, "lead");
        assert!(provider.is_in_fallback_mode().await);

        let second = provider.complete("system", &[], &[]).await;
        assert!(second.is_ok());
        assert_eq!(second.unwrap().1.model, "lead");
        assert!(!provider.is_in_fallback_mode().await);
    }
}