use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::process::Command;
use tokio::sync::Mutex;
use tracing::{debug, warn};
use crate::config::RunnerConfig;
use crate::process::{run_cli_command, CliOutput};
use crate::types::RunnerError;
/// Output size limit in bytes: 50 MiB (`50 << 20`). Not referenced in this
/// section; presumably the cap for regular (non-health-check) CLI runs.
pub const MAX_OUTPUT_BYTES: usize = 50 << 20;
/// Timeout handed to `run_cli_command` for the `--version` health probe (10 s).
pub const HEALTH_CHECK_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Output size limit handed to `run_cli_command` for the health probe: 4 KiB.
pub const HEALTH_CHECK_MAX_OUTPUT: usize = 4 << 10;
/// Shared state for runners that drive an external CLI binary.
///
/// Concrete runners presumably embed this in a field named `base` (the
/// `delegate_provider_base!` macro in this file reads `self.base`); confirm
/// at the call sites.
pub struct CliRunnerBase {
    /// Runner configuration. `binary_path` is the executable that
    /// `health_check` launches; `model`, when provided, overrides the default
    /// model passed to `new`.
    pub(crate) config: RunnerConfig,
    /// Model name resolved in `new` from `config.model` or the `default_model`
    /// argument.
    pub(crate) default_model: String,
    /// Model names reported by `available_models()`, built from the
    /// `fallback_models` argument of `new`.
    pub(crate) available_models: Vec<String>,
    /// Map from a caller-chosen key to a session id. It sits behind a tokio
    /// async mutex and is shared through `Arc`.
    pub(crate) session_ids: Arc<Mutex<HashMap<String, String>>>,
}
impl CliRunnerBase {
    /// Maximum number of characters of CLI output quoted in a failure message.
    /// This keeps one enormous output line from ballooning an error or log entry.
    const FAILURE_DETAIL_MAX_CHARS: usize = 500;

    /// Builds the shared runner state.
    ///
    /// * `config`: runner configuration. A non-blank `config.model` takes
    ///   precedence over `default_model`.
    /// * `default_model`: model name used when `config.model` is unset or blank.
    /// * `fallback_models`: model names reported by `available_models()`.
    pub fn new(config: RunnerConfig, default_model: &str, fallback_models: &[&str]) -> Self {
        // A blank override (e.g. an empty config value) is not a usable model
        // name, so fall back to `default_model` as if it were unset.
        let resolved_model = config
            .model
            .as_deref()
            .filter(|model| !model.trim().is_empty())
            .unwrap_or(default_model)
            .to_owned();
        let available_models = fallback_models
            .iter()
            .map(|&model| model.to_owned())
            .collect();
        Self {
            config,
            default_model: resolved_model,
            available_models,
            session_ids: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the model name resolved in `new`.
    pub fn default_model(&self) -> &str {
        &self.default_model
    }

    /// Returns the model names supplied as `fallback_models` to `new`.
    pub fn available_models(&self) -> &[String] {
        &self.available_models
    }

    /// Records `session_id` under `key`, replacing any previous value.
    pub async fn set_session(&self, key: &str, session_id: &str) {
        let mut sessions = self.session_ids.lock().await;
        sessions.insert(key.to_owned(), session_id.to_owned());
    }

    /// Returns a copy of the session id stored under `key`, if any.
    pub async fn get_session(&self, key: &str) -> Option<String> {
        let sessions = self.session_ids.lock().await;
        sessions.get(key).cloned()
    }

    /// Probes the CLI by running `<config.binary_path> --version`.
    ///
    /// Returns `Ok(true)` when the process exits with code 0. Any other exit
    /// code returns `Ok(false)` and logs a warning with a short excerpt of
    /// the output.
    ///
    /// # Errors
    /// Propagates any error returned by `run_cli_command`, which runs with
    /// `HEALTH_CHECK_TIMEOUT` and `HEALTH_CHECK_MAX_OUTPUT`.
    pub async fn health_check(&self, runner_name: &str) -> Result<bool, RunnerError> {
        let mut cmd = Command::new(&self.config.binary_path);
        cmd.arg("--version");
        let output =
            run_cli_command(&mut cmd, HEALTH_CHECK_TIMEOUT, HEALTH_CHECK_MAX_OUTPUT).await?;
        if output.exit_code == 0 {
            debug!("{runner_name} health check passed");
            return Ok(true);
        }
        let detail = Self::failure_detail(&output);
        warn!(
            exit_code = output.exit_code,
            detail = %detail,
            "{runner_name} health check failed"
        );
        Ok(false)
    }

    /// Converts a non-zero exit code into a `RunnerError::external_service`.
    ///
    /// Returns `Ok(())` when `output.exit_code == 0`.
    ///
    /// # Errors
    /// For any other exit code it logs the code and the output sizes. It then
    /// returns an error whose message includes the exit code and a short
    /// excerpt of the output.
    pub fn check_exit_code(
        &self,
        output: &CliOutput,
        runner_name: &str,
    ) -> Result<(), RunnerError> {
        if output.exit_code == 0 {
            return Ok(());
        }
        warn!(
            exit_code = output.exit_code,
            stdout_len = output.stdout.len(),
            stderr_len = output.stderr.len(),
            "{runner_name} CLI failed"
        );
        let detail = Self::failure_detail(output);
        Err(RunnerError::external_service(
            runner_name,
            format!(
                "{runner_name} exited with code {}: {detail}",
                output.exit_code
            ),
        ))
    }

    /// Picks a short, human-readable reason for a failed invocation, in this
    /// order of preference:
    /// 1. the first non-blank line of stderr;
    /// 2. the first non-blank line of stdout (some CLIs report errors there);
    /// 3. `"(no output)"` when both are blank.
    fn failure_detail(output: &CliOutput) -> String {
        Self::first_meaningful_line(&output.stderr)
            .or_else(|| Self::first_meaningful_line(&output.stdout))
            .unwrap_or_else(|| "(no output)".to_owned())
    }

    /// Returns the first non-blank line of `bytes`, decoded lossily as UTF-8.
    /// The line is trimmed and capped at `FAILURE_DETAIL_MAX_CHARS` characters,
    /// cut on a char boundary with a trailing `…`.
    fn first_meaningful_line(bytes: &[u8]) -> Option<String> {
        let text = String::from_utf8_lossy(bytes);
        let line = text.lines().map(str::trim).find(|l| !l.is_empty())?;
        Some(match line.char_indices().nth(Self::FAILURE_DETAIL_MAX_CHARS) {
            Some((cut, _)) => format!("{}…", &line[..cut]),
            None => line.to_owned(),
        })
    }
}
/// Generates the boilerplate provider methods for a runner that keeps a
/// `CliRunnerBase` in a field named `base`.
///
/// It expands to `name`, `display_name`, `capabilities`, `default_model`,
/// `available_models` and `health_check` method definitions. It must
/// therefore be invoked inside an `impl` block, presumably the provider
/// trait's impl; confirm against the trait definition.
///
/// * `$runner_name`: returned by `name()`, so it must be a `&'static str`.
///   It is also passed to `CliRunnerBase::health_check`, which uses it in
///   its log messages.
/// * `$display_name`: returned by `display_name()`.
/// * `$caps`: expression evaluated on every `capabilities()` call to build a
///   `$crate::types::LlmCapabilities`.
#[macro_export]
macro_rules! delegate_provider_base {
    ($runner_name:expr, $display_name:expr, $caps:expr) => {
        fn name(&self) -> &'static str {
            $runner_name
        }
        fn display_name(&self) -> &str {
            $display_name
        }
        fn capabilities(&self) -> $crate::types::LlmCapabilities {
            $caps
        }
        // The next two simply forward to the wrapped `CliRunnerBase`.
        fn default_model(&self) -> &str {
            self.base.default_model()
        }
        fn available_models(&self) -> &[String] {
            self.base.available_models()
        }
        // Written out in the boxed, pinned, `Send` future form (with the
        // `'life0` / `'async_trait` lifetimes) that `#[async_trait]` produces
        // for an `async fn(&self)`. Presumably this is so it matches a trait
        // declared with `#[async_trait]`; keep it in sync if that trait's
        // signature changes.
        fn health_check<'life0, 'async_trait>(
            &'life0 self,
        ) -> ::core::pin::Pin<
            Box<
                dyn ::core::future::Future<Output = Result<bool, $crate::types::RunnerError>>
                    + ::core::marker::Send
                    + 'async_trait,
            >,
        >
        where
            'life0: 'async_trait,
            Self: 'async_trait,
        {
            Box::pin(async move { self.base.health_check($runner_name).await })
        }
    };
}