use std::sync::{Arc, OnceLock, RwLock};
use serde::{Deserialize, Serialize};
use crate::stream::StreamEvent;
use crate::tasks::generate::GenerateRequest;
use crate::InferenceError;
/// Final output of a successful [`InferenceRunner::run`] call.
///
/// Field order and serde attributes determine the serialized form; keep them
/// stable.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RunnerResult {
/// Text produced by the run.
pub text: String,
/// Tool calls produced by the run. `#[serde(default)]` lets this field be
/// omitted from serialized input, in which case it deserializes as empty.
#[serde(default)]
pub tool_calls: Vec<crate::tasks::generate::ToolCall>,
}
/// Error returned by [`InferenceRunner::run`].
///
/// Both variants are converted into `InferenceError::InferenceFailed` (carrying
/// this error's `Display` text) by the `From` impl below.
#[derive(Debug, Clone, thiserror::Error)]
pub enum RunnerError {
/// The runner declined the request; displays as `runner declined: <reason>`.
/// NOTE(review): presumably "did not attempt it" as opposed to `Failed` — confirm
/// with runner implementations.
#[error("runner declined: {0}")]
Declined(String),
/// The runner failed; displays as `runner failed: <reason>`.
#[error("runner failed: {0}")]
Failed(String),
}
impl From<RunnerError> for InferenceError {
fn from(value: RunnerError) -> Self {
InferenceError::InferenceFailed(value.to_string())
}
}
/// Cloneable handle through which an [`InferenceRunner`] pushes
/// [`StreamEvent`]s to a consumer over a bounded tokio mpsc channel.
#[derive(Clone)]
pub struct EventEmitter {
// Sending half of the channel; the consumer owns the matching receiver.
tx: tokio::sync::mpsc::Sender<StreamEvent>,
}
impl EventEmitter {
    /// Wraps the sending half of a stream-event channel.
    pub(crate) fn new(sender: tokio::sync::mpsc::Sender<StreamEvent>) -> Self {
        Self { tx: sender }
    }

    /// Sends `event` to the consumer, waiting for channel capacity if the
    /// bounded channel is full.
    ///
    /// Delivery is best-effort: if the receiving half has been dropped, the
    /// event is discarded and no error is surfaced to the caller.
    pub async fn emit(&self, event: StreamEvent) {
        let outcome = self.tx.send(event).await;
        // An `Err` only means the receiver is gone; intentionally ignored.
        drop(outcome);
    }

    /// Returns `true` once the receiving half of the channel has been closed
    /// (dropped, or explicitly closed by the consumer).
    pub fn is_closed(&self) -> bool {
        let sender = &self.tx;
        sender.is_closed()
    }
}
/// Pluggable backend that executes a [`GenerateRequest`].
///
/// The `Send + Sync` supertraits make `Arc<dyn InferenceRunner>` shareable
/// across threads, which the process-wide slot behind
/// [`set_inference_runner`] / [`current_inference_runner`] relies on.
#[async_trait::async_trait]
pub trait InferenceRunner: Send + Sync {
/// Executes `request`, optionally streaming intermediate [`StreamEvent`]s
/// through `emitter`, and returns the final [`RunnerResult`].
///
/// `emitter` is taken by value, so unless the implementation clones it
/// elsewhere, the channel sender is dropped when `run` finishes.
///
/// # Errors
///
/// Returns a [`RunnerError`] if the runner declines or fails the request.
async fn run(
&self,
request: GenerateRequest,
emitter: EventEmitter,
) -> Result<RunnerResult, RunnerError>;
}
/// Lazily initialised, process-wide storage for the active inference runner.
///
/// The slot is created empty (`None`) the first time any caller touches it.
fn runner_slot() -> &'static RwLock<Option<Arc<dyn InferenceRunner>>> {
    static RUNNER: OnceLock<RwLock<Option<Arc<dyn InferenceRunner>>>> = OnceLock::new();
    // `RwLock::default()` wraps `Option::default()`, i.e. `None`.
    RUNNER.get_or_init(RwLock::default)
}
/// Installs `runner` as the process-wide inference runner, or clears the slot
/// when given `None`.
///
/// The previously installed runner is taken out of the slot and dropped only
/// after the write lock has been released. Dropping it may run the runner's
/// `Drop` impl (if this held the last `Arc`). Doing that under the lock would
/// deadlock or panic if the `Drop` code called [`current_inference_runner`] or
/// this function, and a panicking `Drop` would poison the lock.
///
/// A poisoned lock is recovered rather than propagated. The slot holds a plain
/// `Option<Arc<_>>`, so there is no partially updated state to protect.
pub fn set_inference_runner(runner: Option<Arc<dyn InferenceRunner>>) {
    let previous = {
        let mut guard = runner_slot()
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        std::mem::replace(&mut *guard, runner)
    };
    // The guard went out of scope above, so any `Drop` side effects of the old
    // runner run without the lock held.
    drop(previous);
}
/// Returns a clone of the currently installed inference runner, if any.
///
/// The returned `Arc` is independent of the slot: a later
/// [`set_inference_runner`] call does not affect a runner already handed out.
///
/// A poisoned lock is recovered rather than propagated. The slot only holds an
/// `Option<Arc<_>>` with no invariant a panic could break, and panicking here
/// would turn one failure elsewhere into a permanent failure for every caller.
pub fn current_inference_runner() -> Option<Arc<dyn InferenceRunner>> {
    runner_slot()
        .read()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .clone()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test runner that echoes the prompt back as `echo:<prompt>`, emitting one
    /// `TextDelta` followed by `Done` before returning the same text.
    struct EchoRunner;

    #[async_trait::async_trait]
    impl InferenceRunner for EchoRunner {
        async fn run(
            &self,
            request: GenerateRequest,
            emitter: EventEmitter,
        ) -> Result<RunnerResult, RunnerError> {
            let text = format!("echo:{}", request.prompt);
            emitter.emit(StreamEvent::TextDelta(text.clone())).await;
            emitter
                .emit(StreamEvent::Done {
                    text: text.clone(),
                    tool_calls: vec![],
                })
                .await;
            Ok(RunnerResult {
                text,
                tool_calls: vec![],
            })
        }
    }

    // NOTE(review): this test mutates the process-wide runner slot; another test
    // installing a runner concurrently could make it flaky.
    #[test]
    fn slot_round_trips() {
        set_inference_runner(None);
        assert!(current_inference_runner().is_none());
        set_inference_runner(Some(Arc::new(EchoRunner)));
        assert!(current_inference_runner().is_some());
        set_inference_runner(None);
        assert!(current_inference_runner().is_none());
    }

    #[tokio::test]
    async fn runner_can_emit_then_finish() {
        let runner: Arc<dyn InferenceRunner> = Arc::new(EchoRunner);
        let (tx, mut rx) = tokio::sync::mpsc::channel::<StreamEvent>(8);
        let emitter = EventEmitter::new(tx);
        let request = GenerateRequest {
            prompt: "hi".into(),
            ..Default::default()
        };
        let result = runner.run(request, emitter).await.unwrap();
        assert_eq!(result.text, "echo:hi");

        // Drain buffered events. The loop ends when the channel closes (the
        // emitter was dropped by `run`) or, as a safety net, after a short idle
        // timeout.
        let mut got = Vec::new();
        while let Ok(evt) =
            tokio::time::timeout(std::time::Duration::from_millis(20), rx.recv()).await
        {
            match evt {
                Some(e) => got.push(e),
                None => break,
            }
        }

        assert_eq!(got.len(), 2);
        // `matches!` only yields a bool; it must be asserted to check anything.
        assert!(matches!(&got[0], StreamEvent::TextDelta(delta) if delta == "echo:hi"));
        assert!(matches!(&got[1], StreamEvent::Done { text, .. } if text == "echo:hi"));
    }
}