embacle_server/runner/
multiplex.rs1use std::sync::Arc;
8use std::time::Instant;
9
10use embacle::config::CliRunnerType;
11use embacle::types::{ChatMessage, ChatRequest, RunnerError};
12
13use crate::state::SharedState;
14
15#[derive(Debug)]
17pub struct MultiplexResult {
18 pub responses: Vec<ProviderResponse>,
20 pub summary: String,
22}
23
/// Outcome of dispatching a chat request to a single provider.
///
/// As constructed in this module, a successful call populates `content` and
/// `model` and leaves `error` as `None`; a failed call sets `error` and leaves
/// `content` and `model` as `None`.
///
/// Derives `Clone`/`PartialEq`/`Eq` (all fields support them) so callers can
/// copy, compare, and assert on responses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderResponse {
    /// Display name of the provider this response belongs to.
    pub provider: String,
    /// Completion text returned by the provider, if it succeeded.
    pub content: Option<String>,
    /// Model identifier reported by the provider, if it succeeded.
    pub model: Option<String>,
    /// Error description when the runner lookup, the completion, or the task failed.
    pub error: Option<String>,
    /// Wall-clock time attributed to this provider, in milliseconds.
    pub duration_ms: u64,
}
38
39pub struct MultiplexEngine {
41 state: SharedState,
42}
43
44impl MultiplexEngine {
45 pub fn new(state: &SharedState) -> Self {
47 Self {
48 state: Arc::clone(state),
49 }
50 }
51
52 pub async fn execute(
57 &self,
58 messages: &[ChatMessage],
59 providers: &[CliRunnerType],
60 temperature: Option<f32>,
61 max_tokens: Option<u32>,
62 ) -> Result<MultiplexResult, RunnerError> {
63 let mut handles = Vec::with_capacity(providers.len());
64
65 for &provider in providers {
66 let state = Arc::clone(&self.state);
67 let messages = messages.to_vec();
68
69 handles.push(tokio::spawn(async move {
70 dispatch_single(state, provider, messages, temperature, max_tokens).await
71 }));
72 }
73
74 let mut responses = Vec::with_capacity(handles.len());
75 for handle in handles {
76 match handle.await {
77 Ok(resp) => responses.push(resp),
78 Err(e) => responses.push(ProviderResponse {
79 provider: "unknown".to_owned(),
80 content: None,
81 model: None,
82 error: Some(format!("Task join error: {e}")),
83 duration_ms: 0,
84 }),
85 }
86 }
87
88 let summary = build_summary(&responses);
89 Ok(MultiplexResult { responses, summary })
90 }
91}
92
93async fn dispatch_single(
95 state: SharedState,
96 provider: CliRunnerType,
97 messages: Vec<ChatMessage>,
98 temperature: Option<f32>,
99 max_tokens: Option<u32>,
100) -> ProviderResponse {
101 let start = Instant::now();
102
103 let runner = state.get_runner(provider).await;
104
105 let runner = match runner {
106 Ok(r) => r,
107 Err(e) => {
108 return ProviderResponse {
109 provider: provider.to_string(),
110 content: None,
111 model: None,
112 error: Some(e.to_string()),
113 duration_ms: elapsed_ms(start),
114 };
115 }
116 };
117
118 let mut request = ChatRequest::new(messages);
119 request.temperature = temperature;
120 request.max_tokens = max_tokens;
121 match runner.complete(&request).await {
122 Ok(response) => ProviderResponse {
123 provider: provider.to_string(),
124 content: Some(response.content),
125 model: Some(response.model),
126 error: None,
127 duration_ms: elapsed_ms(start),
128 },
129 Err(e) => ProviderResponse {
130 provider: provider.to_string(),
131 content: None,
132 model: None,
133 error: Some(e.to_string()),
134 duration_ms: elapsed_ms(start),
135 },
136 }
137}
138
139fn build_summary(responses: &[ProviderResponse]) -> String {
141 let total = responses.len();
142 let succeeded = responses.iter().filter(|r| r.content.is_some()).count();
143 let failed = total - succeeded;
144 format!("{succeeded} succeeded, {failed} failed out of {total} providers")
145}
146
/// Milliseconds elapsed since `start`, saturating at `u64::MAX`.
///
/// `Duration::as_millis` yields a `u128`; values that do not fit in `u64` are
/// clamped rather than truncated or causing a panic.
fn elapsed_ms(start: Instant) -> u64 {
    match u64::try_from(start.elapsed().as_millis()) {
        Ok(ms) => ms,
        Err(_) => u64::MAX,
    }
}