// embacle_server/runner/multiplex.rs
// ABOUTME: Multiplex engine that fans out prompts to multiple providers concurrently
// ABOUTME: Collects per-provider results with timing and produces an aggregated summary
//
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2026 dravr.ai

use std::sync::Arc;
use std::time::Instant;

use embacle::config::CliRunnerType;
use embacle::types::{ChatMessage, ChatRequest, RunnerError};

use crate::state::SharedState;

/// Aggregated result from dispatching a prompt to multiple providers
#[derive(Debug)]
pub struct MultiplexResult {
    /// Individual per-provider responses, in the same order as the
    /// providers passed to [`MultiplexEngine::execute`]
    pub responses: Vec<ProviderResponse>,
    /// Human-readable summary of the multiplex operation
    /// (e.g. "2 succeeded, 1 failed out of 3 providers")
    pub summary: String,
}

/// Response from a single provider in a multiplex operation
///
/// Exactly one of `content` and `error` is `Some`: a successful dispatch
/// populates `content` and `model`, a failed one populates `error`.
#[derive(Debug)]
pub struct ProviderResponse {
    /// Provider identifier (the `Display` form of the runner type)
    pub provider: String,
    /// Response content (None on failure)
    pub content: Option<String>,
    /// Model used by the provider (None on failure)
    pub model: Option<String>,
    /// Error message (None on success)
    pub error: Option<String>,
    /// Wall-clock time in milliseconds from dispatch start to completion
    pub duration_ms: u64,
}

/// Engine that dispatches prompts to multiple embacle runners concurrently
pub struct MultiplexEngine {
    // Shared server state used to look up runners per provider
    state: SharedState,
}

44impl MultiplexEngine {
45    /// Create a new multiplex engine backed by the shared server state
46    pub fn new(state: &SharedState) -> Self {
47        Self {
48            state: Arc::clone(state),
49        }
50    }
51
52    /// Execute a prompt against all specified providers concurrently
53    ///
54    /// Each provider runs in its own tokio task. Failures in one provider
55    /// do not affect others — all results are collected and returned.
56    pub async fn execute(
57        &self,
58        messages: &[ChatMessage],
59        providers: &[CliRunnerType],
60        temperature: Option<f32>,
61        max_tokens: Option<u32>,
62    ) -> Result<MultiplexResult, RunnerError> {
63        let mut handles = Vec::with_capacity(providers.len());
64
65        for &provider in providers {
66            let state = Arc::clone(&self.state);
67            let messages = messages.to_vec();
68
69            handles.push(tokio::spawn(async move {
70                dispatch_single(state, provider, messages, temperature, max_tokens).await
71            }));
72        }
73
74        let mut responses = Vec::with_capacity(handles.len());
75        for handle in handles {
76            match handle.await {
77                Ok(resp) => responses.push(resp),
78                Err(e) => responses.push(ProviderResponse {
79                    provider: "unknown".to_owned(),
80                    content: None,
81                    model: None,
82                    error: Some(format!("Task join error: {e}")),
83                    duration_ms: 0,
84                }),
85            }
86        }
87
88        let summary = build_summary(&responses);
89        Ok(MultiplexResult { responses, summary })
90    }
91}
92
93/// Dispatch a single prompt to one provider and capture the result
94async fn dispatch_single(
95    state: SharedState,
96    provider: CliRunnerType,
97    messages: Vec<ChatMessage>,
98    temperature: Option<f32>,
99    max_tokens: Option<u32>,
100) -> ProviderResponse {
101    let start = Instant::now();
102
103    let runner = state.get_runner(provider).await;
104
105    let runner = match runner {
106        Ok(r) => r,
107        Err(e) => {
108            return ProviderResponse {
109                provider: provider.to_string(),
110                content: None,
111                model: None,
112                error: Some(e.to_string()),
113                duration_ms: elapsed_ms(start),
114            };
115        }
116    };
117
118    let mut request = ChatRequest::new(messages);
119    request.temperature = temperature;
120    request.max_tokens = max_tokens;
121    match runner.complete(&request).await {
122        Ok(response) => ProviderResponse {
123            provider: provider.to_string(),
124            content: Some(response.content),
125            model: Some(response.model),
126            error: None,
127            duration_ms: elapsed_ms(start),
128        },
129        Err(e) => ProviderResponse {
130            provider: provider.to_string(),
131            content: None,
132            model: None,
133            error: Some(e.to_string()),
134            duration_ms: elapsed_ms(start),
135        },
136    }
137}
138
139/// Build a human-readable summary from multiplex responses
140fn build_summary(responses: &[ProviderResponse]) -> String {
141    let total = responses.len();
142    let succeeded = responses.iter().filter(|r| r.content.is_some()).count();
143    let failed = total - succeeded;
144    format!("{succeeded} succeeded, {failed} failed out of {total} providers")
145}
146
/// Convert elapsed time to milliseconds as u64
///
/// Saturates at `u64::MAX` if the `u128` millisecond count is too large
/// to fit (practically unreachable, but keeps the conversion total).
fn elapsed_ms(start: Instant) -> u64 {
    start
        .elapsed()
        .as_millis()
        .try_into()
        .unwrap_or(u64::MAX)
}