use std::pin::Pin;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use crate::council::config::SamplingConfig;
use crate::council::event::ExpertId;
/// A single message in a chat conversation: who authored it and its text.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatMessage {
/// Author of the message; serialized as a lowercase string (see [`ChatRole`]).
pub role: ChatRole,
/// Message text.
pub content: String,
}
/// Author of a [`ChatMessage`].
///
/// Serde encodes each variant as its lowercase name (`"system"`, `"user"`,
/// `"assistant"`) because of `rename_all = "lowercase"` below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
/// Serialized as `"system"`.
System,
/// Serialized as `"user"`.
User,
/// Serialized as `"assistant"`.
Assistant,
}
impl ChatRole {
pub fn as_wire_str(self) -> &'static str {
match self {
ChatRole::System => "system",
ChatRole::User => "user",
ChatRole::Assistant => "assistant",
}
}
}
/// Input to [`Agent::chat`]: the target model, the messages to send, and sampling settings.
#[derive(Debug, Clone, PartialEq)]
pub struct ChatRequest {
/// Model name to request. NOTE(review): presumably checked against the backend's reply
/// (see `AgentError::ModelMismatch`) — confirm in real `Agent` implementations.
pub model: String,
/// Messages sent to the model. NOTE(review): presumably oldest first — confirm with callers.
pub messages: Vec<ChatMessage>,
/// Sampling parameters (type defined in `crate::council::config`).
pub sampling: SamplingConfig,
/// Optional caller-supplied identifier for this request.
/// NOTE(review): its purpose (tracing/correlation?) is not visible in this file.
pub request_id: Option<String>,
}
/// One incremental fragment of a streamed chat response, yielded by a [`ChatTokenStream`].
#[derive(Debug, Clone, PartialEq)]
pub struct ChatToken {
/// Text fragment carried by this token.
pub text: String,
/// Whether this token ends the response. The mock agent sets it only on the last
/// token; NOTE(review): confirm real agents follow the same convention.
pub finished: bool,
/// Reason generation stopped, if reported (the mock uses `"stop"` on its final token
/// and `None` elsewhere).
pub finish_reason: Option<String>,
}
/// Failures an [`Agent`] can report; every variant records the failing agent's id.
#[derive(Debug, thiserror::Error)]
pub enum AgentError {
/// The call did not complete within `timeout_ms` milliseconds.
/// NOTE(review): presumably the agent's [`Agent::timeout_ms`] value — confirm where it is enforced.
#[error("agent {agent_id} timed out after {timeout_ms} ms")]
Timeout { agent_id: ExpertId, timeout_ms: u64 },
/// A failure surfaced through the token stream (the test mock emits this for scripted errors).
#[error("agent {agent_id} stream error: {message}")]
Stream { agent_id: ExpertId, message: String },
/// The backend reported a model name different from the one expected.
#[error("agent {agent_id} returned model name `{actual}`; expected `{expected}`")]
ModelMismatch {
/// Agent that produced the mismatching reply.
agent_id: ExpertId,
/// Model name that was expected.
expected: String,
/// Model name the backend actually returned.
actual: String,
},
/// A failure in the underlying transport.
/// NOTE(review): presumably network/HTTP-level errors — confirm against implementors.
#[error("agent {agent_id} transport error: {message}")]
Transport { agent_id: ExpertId, message: String },
}
/// Boxed, pinned stream of chat tokens returned by [`Agent::chat`].
///
/// Each item is either the next [`ChatToken`] or an [`AgentError`]. The `Send + 'static`
/// bounds let the stream move across threads and live independently of the `&self`
/// borrow that created it.
pub type ChatTokenStream =
Pin<Box<dyn Stream<Item = Result<ChatToken, AgentError>> + Send + 'static>>;
/// A chat-capable backend identified by an [`ExpertId`].
///
/// Implementors must be `Send + Sync`, so one agent instance can be shared across threads.
pub trait Agent: Send + Sync {
/// Identifier of this agent. NOTE(review): presumably the value reported as `agent_id`
/// in [`AgentError`] (the test mock does this) — confirm for real implementations.
fn id(&self) -> &ExpertId;
/// Model name associated with this agent.
fn model(&self) -> &str;
/// Timeout for a chat call, in milliseconds.
/// NOTE(review): who enforces it (agent or caller) is not visible in this file.
fn timeout_ms(&self) -> u64;
/// Issues `request` and returns the resulting token stream.
///
/// There is no `Result` wrapper: failures are delivered as `Err` items in the stream.
fn chat(&self, request: ChatRequest) -> ChatTokenStream;
}
#[cfg(test)]
pub mod testing {
    //! Test doubles for [`Agent`]: a scripted agent whose responses are fixed up front.

    use super::*;
    use futures::stream;
    use std::collections::VecDeque;
    use std::sync::Mutex;

    /// One canned response, consumed by a single [`MockAgent`] `chat` call.
    #[derive(Debug, Clone)]
    pub enum Script {
        /// Yield each fragment as an `Ok` token; only the last one is marked finished,
        /// with `finish_reason` `"stop"`. An empty list yields an empty stream.
        Tokens(Vec<&'static str>),
        /// Yield a single [`AgentError::Stream`] carrying this message, then end.
        Error(&'static str),
    }

    /// An [`Agent`] that replays a queue of [`Script`]s, one per `chat` call.
    ///
    /// Once the queue is exhausted, `chat` returns an empty stream.
    pub struct MockAgent {
        /// Returned by [`Agent::id`] and copied into scripted errors.
        pub id: ExpertId,
        /// Returned by [`Agent::model`].
        pub model: String,
        /// Returned by [`Agent::timeout_ms`]; `new` sets 30 000 and the mock never enforces it.
        pub timeout_ms: u64,
        /// Pending scripts, consumed front-first; `VecDeque` makes each pop O(1).
        scripts: Mutex<VecDeque<Script>>,
    }

    impl MockAgent {
        /// Creates a mock that answers successive `chat` calls with `scripts`, in order.
        pub fn new(id: &str, model: &str, scripts: Vec<Script>) -> Self {
            Self {
                id: id.into(),
                model: model.into(),
                timeout_ms: 30_000,
                scripts: Mutex::new(VecDeque::from(scripts)),
            }
        }

        /// Number of scripts not yet consumed by `chat`.
        pub fn remaining_scripts(&self) -> usize {
            self.scripts
                .lock()
                .expect("MockAgent script lock poisoned")
                .len()
        }

        /// Removes and returns the next script, or `None` once the queue is empty.
        fn next_script(&self) -> Option<Script> {
            // The guard is a temporary, so the lock is released before any stream is built.
            self.scripts
                .lock()
                .expect("MockAgent script lock poisoned")
                .pop_front()
        }
    }

    impl Agent for MockAgent {
        fn id(&self) -> &ExpertId {
            &self.id
        }

        fn model(&self) -> &str {
            &self.model
        }

        fn timeout_ms(&self) -> u64 {
            self.timeout_ms
        }

        /// Ignores `request` and replays the next queued [`Script`].
        ///
        /// The script is consumed when `chat` is called, not when the stream is polled.
        fn chat(&self, _request: ChatRequest) -> ChatTokenStream {
            match self.next_script() {
                None => Box::pin(stream::empty()),
                Some(Script::Tokens(toks)) => Box::pin(stream::iter(scripted_tokens(toks))),
                Some(Script::Error(msg)) => {
                    let err = AgentError::Stream {
                        agent_id: self.id.clone(),
                        message: msg.into(),
                    };
                    Box::pin(stream::iter(vec![Err(err)]))
                }
            }
        }
    }

    /// Turns script fragments into `Ok` tokens, flagging only the last as finished
    /// with `finish_reason == Some("stop")`.
    fn scripted_tokens(toks: Vec<&'static str>) -> Vec<Result<ChatToken, AgentError>> {
        let last_idx = toks.len().saturating_sub(1);
        toks.into_iter()
            .enumerate()
            .map(|(i, text)| {
                let finished = i == last_idx;
                Ok(ChatToken {
                    text: text.to_owned(),
                    finished,
                    finish_reason: finished.then(|| "stop".to_owned()),
                })
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::testing::{MockAgent, Script};
    use super::*;
    use futures::StreamExt;

    /// Minimal one-message request shared by every test; the mock ignores its contents.
    fn req() -> ChatRequest {
        let user_hello = ChatMessage {
            role: ChatRole::User,
            content: "hi".into(),
        };
        ChatRequest {
            model: "test-model".into(),
            messages: vec![user_hello],
            sampling: SamplingConfig::default(),
            request_id: None,
        }
    }

    /// Tokens arrive in script order, and the final one is flagged as finished.
    #[tokio::test]
    async fn mock_agent_yields_scripted_tokens_in_order() {
        let script = Script::Tokens(vec!["hello", " ", "world"]);
        let agent = MockAgent::new("A", "test-model", vec![script]);

        let tokens: Vec<ChatToken> = agent
            .chat(req())
            .map(|item| item.expect("token ok"))
            .collect()
            .await;

        let texts: Vec<&str> = tokens.iter().map(|tok| tok.text.as_str()).collect();
        assert_eq!(texts, vec!["hello", " ", "world"]);
        assert!(
            matches!(tokens.last(), Some(tok) if tok.finished),
            "final token must have finished=true"
        );
    }

    /// An error script yields exactly one `AgentError::Stream` tagged with the agent id.
    #[tokio::test]
    async fn mock_agent_yields_error_when_scripted() {
        let agent = MockAgent::new("B", "test-model", vec![Script::Error("kaboom")]);
        let mut stream = agent.chat(req());

        let first = stream.next().await.expect("one item");
        if let Err(AgentError::Stream {
            ref agent_id,
            ref message,
        }) = first
        {
            assert_eq!(agent_id, "B");
            assert!(message.contains("kaboom"));
        } else {
            panic!("expected Stream error, got {first:?}");
        }

        assert!(stream.next().await.is_none());
    }

    /// Each `chat` call pops exactly one script; an exhausted mock yields an empty stream.
    #[tokio::test]
    async fn mock_agent_consumes_one_script_per_call() {
        let rounds = vec![
            Script::Tokens(vec!["round0"]),
            Script::Tokens(vec!["round1"]),
        ];
        let agent = MockAgent::new("A", "m", rounds);
        assert_eq!(agent.remaining_scripts(), 2);

        for expected_left in [1, 0] {
            let _ = agent.chat(req()).collect::<Vec<_>>().await;
            assert_eq!(agent.remaining_scripts(), expected_left);
        }

        let third: Vec<_> = agent.chat(req()).collect().await;
        assert!(third.is_empty());
    }

    /// Trait accessors expose the constructor arguments and the default timeout.
    #[tokio::test]
    async fn agent_metadata_is_accessible() {
        let agent = MockAgent::new("X", "qwen3", Vec::new());
        assert_eq!(agent.timeout_ms(), 30_000);
        assert_eq!(agent.model(), "qwen3");
        assert_eq!(agent.id(), "X");
    }
}