embacle_server/runner/
multiplex.rs1use std::sync::Arc;
8use std::time::Instant;
9
10use embacle::config::CliRunnerType;
11use embacle::types::{ChatMessage, ChatRequest, ResponseFormat, RunnerError};
12
13use crate::state::SharedState;
14
15#[derive(Debug, Clone, Default)]
17pub struct MultiplexParams {
18 pub temperature: Option<f32>,
20 pub max_tokens: Option<u32>,
22 pub top_p: Option<f32>,
24 pub stop: Option<Vec<String>>,
26 pub response_format: Option<ResponseFormat>,
28}
29
30#[derive(Debug)]
32pub struct MultiplexResult {
33 pub responses: Vec<ProviderResponse>,
35 pub summary: String,
37}
38
/// Outcome of dispatching a conversation to a single provider.
///
/// Every response built by this module has exactly one of `content` and
/// `error` set to `Some`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderResponse {
    /// Provider this response belongs to, normally rendered via
    /// `CliRunnerType`'s `Display` impl.
    pub provider: String,
    /// Completion text on success; `None` when the provider failed.
    pub content: Option<String>,
    /// Model name reported by the provider on success; `None` on failure.
    pub model: Option<String>,
    /// Description of the failure; `None` on success.
    pub error: Option<String>,
    /// Wall-clock milliseconds spent on this provider (0 when its task could
    /// not be joined, since no timing was recorded).
    pub duration_ms: u64,
}
53
54pub struct MultiplexEngine {
56 state: SharedState,
57}
58
59impl MultiplexEngine {
60 pub fn new(state: &SharedState) -> Self {
62 Self {
63 state: Arc::clone(state),
64 }
65 }
66
67 pub async fn execute(
72 &self,
73 messages: &[ChatMessage],
74 providers: &[CliRunnerType],
75 params: &MultiplexParams,
76 ) -> Result<MultiplexResult, RunnerError> {
77 let mut handles = Vec::with_capacity(providers.len());
78
79 for &provider in providers {
80 let state = Arc::clone(&self.state);
81 let messages = messages.to_vec();
82 let params = params.clone();
83
84 handles.push(tokio::spawn(async move {
85 dispatch_single(state, provider, messages, ¶ms).await
86 }));
87 }
88
89 let mut responses = Vec::with_capacity(handles.len());
90 for handle in handles {
91 match handle.await {
92 Ok(resp) => responses.push(resp),
93 Err(e) => responses.push(ProviderResponse {
94 provider: "unknown".to_owned(),
95 content: None,
96 model: None,
97 error: Some(format!("Task join error: {e}")),
98 duration_ms: 0,
99 }),
100 }
101 }
102
103 let summary = build_summary(&responses);
104 Ok(MultiplexResult { responses, summary })
105 }
106}
107
108async fn dispatch_single(
110 state: SharedState,
111 provider: CliRunnerType,
112 messages: Vec<ChatMessage>,
113 params: &MultiplexParams,
114) -> ProviderResponse {
115 let start = Instant::now();
116
117 let runner = state.get_runner(provider).await;
118
119 let runner = match runner {
120 Ok(r) => r,
121 Err(e) => {
122 return ProviderResponse {
123 provider: provider.to_string(),
124 content: None,
125 model: None,
126 error: Some(e.to_string()),
127 duration_ms: elapsed_ms(start),
128 };
129 }
130 };
131
132 let mut request = ChatRequest::new(messages);
133 request.temperature = params.temperature;
134 request.max_tokens = params.max_tokens;
135 request.top_p = params.top_p;
136 request.stop.clone_from(¶ms.stop);
137 request.response_format.clone_from(¶ms.response_format);
138 match runner.complete(&request).await {
139 Ok(response) => ProviderResponse {
140 provider: provider.to_string(),
141 content: Some(response.content),
142 model: Some(response.model),
143 error: None,
144 duration_ms: elapsed_ms(start),
145 },
146 Err(e) => ProviderResponse {
147 provider: provider.to_string(),
148 content: None,
149 model: None,
150 error: Some(e.to_string()),
151 duration_ms: elapsed_ms(start),
152 },
153 }
154}
155
156fn build_summary(responses: &[ProviderResponse]) -> String {
158 let total = responses.len();
159 let succeeded = responses.iter().filter(|r| r.content.is_some()).count();
160 let failed = total - succeeded;
161 format!("{succeeded} succeeded, {failed} failed out of {total} providers")
162}
163
/// Whole milliseconds elapsed since `start`, saturating at `u64::MAX`.
fn elapsed_ms(start: Instant) -> u64 {
    match u64::try_from(start.elapsed().as_millis()) {
        Ok(ms) => ms,
        // `as_millis` returns `u128`; saturate rather than truncate or panic.
        Err(_) => u64::MAX,
    }
}