// embacle_server/runner/multiplex.rs
// ABOUTME: Multiplex engine that fans out prompts to multiple providers concurrently
// ABOUTME: Collects per-provider results with timing and produces an aggregated summary
//
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2026 dravr.ai

7use std::sync::Arc;
8use std::time::Instant;
9
10use embacle::config::CliRunnerType;
11use embacle::types::{ChatMessage, ChatRequest, ResponseFormat, RunnerError};
12
13use crate::state::SharedState;
14
/// Optional request parameters forwarded to each provider in a multiplex dispatch
///
/// Every field is optional; fields left as `None` are forwarded unset on the
/// per-provider `ChatRequest`, so each provider falls back to its own default.
#[derive(Debug, Clone, Default)]
pub struct MultiplexParams {
    /// Temperature for response randomness
    pub temperature: Option<f32>,
    /// Maximum tokens to generate
    pub max_tokens: Option<u32>,
    /// Nucleus sampling parameter
    pub top_p: Option<f32>,
    /// Stop sequences that halt generation
    pub stop: Option<Vec<String>>,
    /// Response format control
    pub response_format: Option<ResponseFormat>,
}
29
/// Aggregated result from dispatching a prompt to multiple providers
#[derive(Debug)]
pub struct MultiplexResult {
    /// Individual per-provider responses, one entry per requested provider
    pub responses: Vec<ProviderResponse>,
    /// Human-readable summary of the multiplex operation
    /// (e.g. "2 succeeded, 1 failed out of 3 providers")
    pub summary: String,
}
38
/// Response from a single provider in a multiplex operation
///
/// Exactly one of `content` and `error` is populated: `content` on success,
/// `error` on failure.
#[derive(Debug)]
pub struct ProviderResponse {
    /// Provider identifier
    pub provider: String,
    /// Response content (None on failure)
    pub content: Option<String>,
    /// Model used by the provider
    pub model: Option<String>,
    /// Error message (None on success)
    pub error: Option<String>,
    /// Wall-clock time in milliseconds spent on this provider's dispatch
    /// (0 when the task could not be joined at all)
    pub duration_ms: u64,
}
53
/// Engine that dispatches prompts to multiple embacle runners concurrently
pub struct MultiplexEngine {
    // Shared server state used to look up per-provider runners;
    // cloned cheaply via Arc::clone for each spawned task.
    state: SharedState,
}
58
59impl MultiplexEngine {
60    /// Create a new multiplex engine backed by the shared server state
61    pub fn new(state: &SharedState) -> Self {
62        Self {
63            state: Arc::clone(state),
64        }
65    }
66
67    /// Execute a prompt against all specified providers concurrently
68    ///
69    /// Each provider runs in its own tokio task. Failures in one provider
70    /// do not affect others — all results are collected and returned.
71    pub async fn execute(
72        &self,
73        messages: &[ChatMessage],
74        providers: &[CliRunnerType],
75        params: &MultiplexParams,
76    ) -> Result<MultiplexResult, RunnerError> {
77        let mut handles = Vec::with_capacity(providers.len());
78
79        for &provider in providers {
80            let state = Arc::clone(&self.state);
81            let messages = messages.to_vec();
82            let params = params.clone();
83
84            handles.push(tokio::spawn(async move {
85                dispatch_single(state, provider, messages, &params).await
86            }));
87        }
88
89        let mut responses = Vec::with_capacity(handles.len());
90        for handle in handles {
91            match handle.await {
92                Ok(resp) => responses.push(resp),
93                Err(e) => responses.push(ProviderResponse {
94                    provider: "unknown".to_owned(),
95                    content: None,
96                    model: None,
97                    error: Some(format!("Task join error: {e}")),
98                    duration_ms: 0,
99                }),
100            }
101        }
102
103        let summary = build_summary(&responses);
104        Ok(MultiplexResult { responses, summary })
105    }
106}
107
108/// Dispatch a single prompt to one provider and capture the result
109async fn dispatch_single(
110    state: SharedState,
111    provider: CliRunnerType,
112    messages: Vec<ChatMessage>,
113    params: &MultiplexParams,
114) -> ProviderResponse {
115    let start = Instant::now();
116
117    let runner = state.get_runner(provider).await;
118
119    let runner = match runner {
120        Ok(r) => r,
121        Err(e) => {
122            return ProviderResponse {
123                provider: provider.to_string(),
124                content: None,
125                model: None,
126                error: Some(e.to_string()),
127                duration_ms: elapsed_ms(start),
128            };
129        }
130    };
131
132    let mut request = ChatRequest::new(messages);
133    request.temperature = params.temperature;
134    request.max_tokens = params.max_tokens;
135    request.top_p = params.top_p;
136    request.stop.clone_from(&params.stop);
137    request.response_format.clone_from(&params.response_format);
138    match runner.complete(&request).await {
139        Ok(response) => ProviderResponse {
140            provider: provider.to_string(),
141            content: Some(response.content),
142            model: Some(response.model),
143            error: None,
144            duration_ms: elapsed_ms(start),
145        },
146        Err(e) => ProviderResponse {
147            provider: provider.to_string(),
148            content: None,
149            model: None,
150            error: Some(e.to_string()),
151            duration_ms: elapsed_ms(start),
152        },
153    }
154}
155
156/// Build a human-readable summary from multiplex responses
157fn build_summary(responses: &[ProviderResponse]) -> String {
158    let total = responses.len();
159    let succeeded = responses.iter().filter(|r| r.content.is_some()).count();
160    let failed = total - succeeded;
161    format!("{succeeded} succeeded, {failed} failed out of {total} providers")
162}
163
/// Convert the time elapsed since `start` into whole milliseconds,
/// saturating at `u64::MAX` if the u128 millisecond count overflows u64.
fn elapsed_ms(start: Instant) -> u64 {
    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX)
}