embacle_mcp/runner/
multiplex.rs1use std::sync::Arc;
8use std::time::Instant;
9
10use embacle::config::CliRunnerType;
11use embacle::types::{ChatMessage, ChatRequest, RunnerError};
12use serde::Serialize;
13
14use crate::state::SharedState;
15
/// Aggregated outcome of sending one conversation to several providers.
#[derive(Debug, Serialize)]
pub struct MultiplexResult {
    /// One entry per requested provider, in the same order as the `providers`
    /// slice passed to `MultiplexEngine::execute`.
    pub responses: Vec<ProviderResponse>,
    /// Human-readable tally of the form
    /// `"{succeeded} succeeded, {failed} failed out of {total} providers"`.
    pub summary: String,
}
24
/// Outcome of a single provider's completion within a multiplexed call.
///
/// Responses built by this module populate exactly one of `content` and
/// `error`. An entry counts as "succeeded" in the summary iff `content` is
/// `Some`.
#[derive(Debug, Serialize)]
pub struct ProviderResponse {
    /// Provider identifier, normally the `Display` form of its `CliRunnerType`.
    pub provider: String,
    /// Completion text returned by the runner; `Some` only on success.
    pub content: Option<String>,
    /// Model identifier reported in the runner's response; `Some` only on success.
    pub model: Option<String>,
    /// Failure description (runner lookup, completion, or task join error);
    /// `None` on success.
    pub error: Option<String>,
    /// Milliseconds from the start of dispatch (including runner lookup) until
    /// this response was built; `0` when the task could not be joined.
    pub duration_ms: u64,
}
39
/// Dispatches one conversation to multiple CLI providers concurrently.
pub struct MultiplexEngine {
    /// Shared application state, read-locked per dispatch to resolve the
    /// runner for each provider.
    state: SharedState,
}
44
45impl MultiplexEngine {
46 pub const fn new(state: SharedState) -> Self {
48 Self { state }
49 }
50
51 pub async fn execute(
56 &self,
57 messages: &[ChatMessage],
58 providers: &[CliRunnerType],
59 ) -> Result<MultiplexResult, RunnerError> {
60 let mut handles = Vec::with_capacity(providers.len());
61
62 for &provider in providers {
63 let state = Arc::clone(&self.state);
64 let messages = messages.to_vec();
65
66 handles.push(tokio::spawn(async move {
67 dispatch_single(state, provider, messages).await
68 }));
69 }
70
71 let mut responses = Vec::with_capacity(handles.len());
72 for handle in handles {
73 match handle.await {
74 Ok(resp) => responses.push(resp),
75 Err(e) => responses.push(ProviderResponse {
76 provider: "unknown".to_owned(),
77 content: None,
78 model: None,
79 error: Some(format!("Task join error: {e}")),
80 duration_ms: 0,
81 }),
82 }
83 }
84
85 let summary = build_summary(&responses);
86 Ok(MultiplexResult { responses, summary })
87 }
88}
89
90async fn dispatch_single(
92 state: SharedState,
93 provider: CliRunnerType,
94 messages: Vec<ChatMessage>,
95) -> ProviderResponse {
96 let start = Instant::now();
97
98 let runner = {
99 let state_guard = state.read().await;
100 state_guard.get_runner(provider).await
101 };
102
103 let runner = match runner {
104 Ok(r) => r,
105 Err(e) => {
106 return ProviderResponse {
107 provider: provider.to_string(),
108 content: None,
109 model: None,
110 error: Some(e.to_string()),
111 duration_ms: elapsed_ms(start),
112 };
113 }
114 };
115
116 let request = ChatRequest::new(messages);
117 match runner.complete(&request).await {
118 Ok(response) => ProviderResponse {
119 provider: provider.to_string(),
120 content: Some(response.content),
121 model: Some(response.model),
122 error: None,
123 duration_ms: elapsed_ms(start),
124 },
125 Err(e) => ProviderResponse {
126 provider: provider.to_string(),
127 content: None,
128 model: None,
129 error: Some(e.to_string()),
130 duration_ms: elapsed_ms(start),
131 },
132 }
133}
134
135fn build_summary(responses: &[ProviderResponse]) -> String {
137 let total = responses.len();
138 let succeeded = responses.iter().filter(|r| r.content.is_some()).count();
139 let failed = total - succeeded;
140 format!("{succeeded} succeeded, {failed} failed out of {total} providers")
141}
142
/// Whole milliseconds elapsed since `start`, saturating at `u64::MAX`.
fn elapsed_ms(start: Instant) -> u64 {
    // `as_millis` yields a `u128`; clamp rather than truncate with `as`.
    start
        .elapsed()
        .as_millis()
        .try_into()
        .unwrap_or(u64::MAX)
}