Skip to main content

embacle_mcp/runner/
multiplex.rs

1// ABOUTME: Multiplex engine that fans out prompts to multiple providers concurrently
2// ABOUTME: Collects per-provider results with timing and produces an aggregated summary
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
7use std::sync::Arc;
8use std::time::Instant;
9
10use embacle::config::CliRunnerType;
11use embacle::types::{ChatMessage, ChatRequest, RunnerError};
12use serde::Serialize;
13
14use crate::state::SharedState;
15
16/// Aggregated result from dispatching a prompt to multiple providers
17#[derive(Debug, Serialize)]
18pub struct MultiplexResult {
19    /// Individual per-provider responses
20    pub responses: Vec<ProviderResponse>,
21    /// Human-readable summary of the multiplex operation
22    pub summary: String,
23}
24
25/// Response from a single provider in a multiplex operation
26#[derive(Debug, Serialize)]
27pub struct ProviderResponse {
28    /// Provider identifier
29    pub provider: String,
30    /// Response content (None on failure)
31    pub content: Option<String>,
32    /// Model used by the provider
33    pub model: Option<String>,
34    /// Error message (None on success)
35    pub error: Option<String>,
36    /// Wall-clock time in milliseconds
37    pub duration_ms: u64,
38}
39
40/// Engine that dispatches prompts to multiple embacle runners concurrently
41pub struct MultiplexEngine {
42    state: SharedState,
43}
44
45impl MultiplexEngine {
46    /// Create a new multiplex engine backed by the shared server state
47    pub const fn new(state: SharedState) -> Self {
48        Self { state }
49    }
50
51    /// Execute a prompt against all specified providers concurrently
52    ///
53    /// Each provider runs in its own tokio task. Failures in one provider
54    /// do not affect others — all results are collected and returned.
55    pub async fn execute(
56        &self,
57        messages: &[ChatMessage],
58        providers: &[CliRunnerType],
59    ) -> Result<MultiplexResult, RunnerError> {
60        let mut handles = Vec::with_capacity(providers.len());
61
62        for &provider in providers {
63            let state = Arc::clone(&self.state);
64            let messages = messages.to_vec();
65
66            handles.push(tokio::spawn(async move {
67                dispatch_single(state, provider, messages).await
68            }));
69        }
70
71        let mut responses = Vec::with_capacity(handles.len());
72        for handle in handles {
73            match handle.await {
74                Ok(resp) => responses.push(resp),
75                Err(e) => responses.push(ProviderResponse {
76                    provider: "unknown".to_owned(),
77                    content: None,
78                    model: None,
79                    error: Some(format!("Task join error: {e}")),
80                    duration_ms: 0,
81                }),
82            }
83        }
84
85        let summary = build_summary(&responses);
86        Ok(MultiplexResult { responses, summary })
87    }
88}
89
90/// Dispatch a single prompt to one provider and capture the result
91async fn dispatch_single(
92    state: SharedState,
93    provider: CliRunnerType,
94    messages: Vec<ChatMessage>,
95) -> ProviderResponse {
96    let start = Instant::now();
97
98    let runner = {
99        let state_guard = state.read().await;
100        state_guard.get_runner(provider).await
101    };
102
103    let runner = match runner {
104        Ok(r) => r,
105        Err(e) => {
106            return ProviderResponse {
107                provider: provider.to_string(),
108                content: None,
109                model: None,
110                error: Some(e.to_string()),
111                duration_ms: elapsed_ms(start),
112            };
113        }
114    };
115
116    let request = ChatRequest::new(messages);
117    match runner.complete(&request).await {
118        Ok(response) => ProviderResponse {
119            provider: provider.to_string(),
120            content: Some(response.content),
121            model: Some(response.model),
122            error: None,
123            duration_ms: elapsed_ms(start),
124        },
125        Err(e) => ProviderResponse {
126            provider: provider.to_string(),
127            content: None,
128            model: None,
129            error: Some(e.to_string()),
130            duration_ms: elapsed_ms(start),
131        },
132    }
133}
134
135/// Build a human-readable summary from multiplex responses
136fn build_summary(responses: &[ProviderResponse]) -> String {
137    let total = responses.len();
138    let succeeded = responses.iter().filter(|r| r.content.is_some()).count();
139    let failed = total - succeeded;
140    format!("{succeeded} succeeded, {failed} failed out of {total} providers")
141}
142
/// Milliseconds elapsed since `start`, saturating at `u64::MAX`
///
/// `Duration::as_millis` yields a `u128`. Values too large for a `u64` are
/// clamped to `u64::MAX` instead of being truncated.
fn elapsed_ms(start: Instant) -> u64 {
    let since_start = start.elapsed();
    u64::try_from(since_start.as_millis())
        .ok()
        .unwrap_or(u64::MAX)
}