// llama-rs 0.17.0
//
// A high-performance Rust implementation of llama.cpp: an LLM inference
// engine with full GGUF support.
//! Agent abstraction — one impl per transport.
//!
//! The trait is `dyn`-compatible: each `chat` call returns a boxed
//! `Stream` of token deltas rather than using `async fn in trait`. This
//! keeps the orchestrator code uniform across `GrpcAgent` and
//! `OpenAiHttpAgent`.

use std::pin::Pin;

use futures::stream::Stream;
use serde::{Deserialize, Serialize};

use crate::council::config::SamplingConfig;
use crate::council::event::ExpertId;

/// A single turn of conversation inside a [`ChatRequest`].
///
/// Serde (de)serializes it as an object with the keys `role` and `content`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who produced this turn.
    pub role: ChatRole,
    /// The turn's text.
    pub content: String,
}

/// Speaker role attached to every [`ChatMessage`].
///
/// Serde writes each variant as its lowercase name. That is the same string
/// [`ChatRole::as_wire_str`] returns, so both encodings agree.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// Encoded as `"system"`.
    System,
    /// Encoded as `"user"`.
    User,
    /// Encoded as `"assistant"`.
    Assistant,
}

impl ChatRole {
    /// Returns the lowercase wire name of this role.
    pub fn as_wire_str(self) -> &'static str {
        match self {
            Self::System => "system",
            Self::User => "user",
            Self::Assistant => "assistant",
        }
    }
}

/// Input for a single [`Agent::chat`] call: one agent, one round.
#[derive(Clone, Debug, PartialEq)]
pub struct ChatRequest {
    /// Name of the model being asked to respond.
    pub model: String,
    /// Conversation history handed to the agent.
    pub messages: Vec<ChatMessage>,
    /// Sampling parameters for this call.
    pub sampling: SamplingConfig,
    /// Optional caller-supplied identifier for this request.
    pub request_id: Option<String>,
}

/// A single incremental chunk yielded by a [`ChatTokenStream`].
#[derive(Clone, Debug, PartialEq)]
pub struct ChatToken {
    /// Text delta carried by this chunk.
    pub text: String,
    /// Set on the chunk that ends the reply.
    pub finished: bool,
    /// Reason generation ended, if the agent reports one.
    pub finish_reason: Option<String>,
}

#[derive(Debug, thiserror::Error)]
pub enum AgentError {
    #[error("agent {agent_id} timed out after {timeout_ms} ms")]
    Timeout { agent_id: ExpertId, timeout_ms: u64 },

    #[error("agent {agent_id} stream error: {message}")]
    Stream { agent_id: ExpertId, message: String },

    #[error("agent {agent_id} returned model name `{actual}`; expected `{expected}`")]
    ModelMismatch {
        agent_id: ExpertId,
        expected: String,
        actual: String,
    },

    #[error("agent {agent_id} transport error: {message}")]
    Transport { agent_id: ExpertId, message: String },
}

/// Type-erased stream of per-chunk results returned by [`Agent::chat`].
///
/// Returning a concrete boxed type, instead of `impl Stream` or an
/// `async fn` in the trait, keeps [`Agent`] usable as a trait object.
/// `Send + 'static` lets the stream be moved into spawned tasks.
pub type ChatTokenStream = Pin<
    Box<dyn Stream<Item = Result<ChatToken, AgentError>> + Send + 'static>,
>;

/// One network-reachable agent. Implementations wrap a transport (gRPC,
/// OpenAI HTTP, …) and produce a stream of token deltas per call.
///
/// The trait is dyn-compatible, so heterogeneous agents can be stored
/// behind `Box<dyn Agent>` / `Arc<dyn Agent>`. The `Send + Sync` bound
/// lets them be shared across threads.
pub trait Agent: Send + Sync {
    /// Stable id for transcripts ("A", "B", …).
    fn id(&self) -> &ExpertId;

    /// Model name advertised by the agent (used to label streams and to
    /// verify the gRPC `Hello` handshake).
    fn model(&self) -> &str;

    /// Configured per-call timeout in milliseconds.
    ///
    /// NOTE(review): this file does not show whether the implementation or
    /// the caller enforces the timeout. `AgentError::Timeout` carries this
    /// value when it fires.
    fn timeout_ms(&self) -> u64;

    /// Open a streaming chat call for `request`.
    ///
    /// The stream yields `Ok(ChatToken)` deltas, or `Err(AgentError)` on
    /// failure. Callers must poll it to receive anything, so discarding
    /// the return value is always a bug. `#[must_use]` turns that mistake
    /// into a compiler warning.
    #[must_use = "the returned stream must be polled to receive tokens"]
    fn chat(&self, request: ChatRequest) -> ChatTokenStream;
}

#[cfg(test)]
pub mod testing {
    //! Mock agents for unit tests.
    //!
    //! `MockAgent` returns a scripted sequence of `ChatToken`s (or one
    //! pre-baked error) from each `chat` call. Each call consumes the
    //! next script entry in FIFO order; if the scripts are exhausted the
    //! agent yields an empty stream.

    use super::*;
    use std::collections::VecDeque;
    use std::sync::Mutex;

    use futures::stream;

    /// A scripted plan for a single `chat` call.
    #[derive(Debug, Clone)]
    pub enum Script {
        /// Yield these tokens then close successfully. The last token has
        /// `finished = true` and `finish_reason = Some("stop")`; an empty
        /// list yields an empty stream.
        Tokens(Vec<&'static str>),
        /// Immediately yield an `AgentError::Stream` carrying this message,
        /// then close.
        Error(&'static str),
    }

    /// Test double for [`Agent`] that replays one [`Script`] per `chat` call.
    pub struct MockAgent {
        pub id: ExpertId,
        pub model: String,
        /// Reported by `timeout_ms()`; the mock itself never times out.
        pub timeout_ms: u64,
        /// Pending scripts; the front entry serves the next `chat` call.
        /// A `VecDeque` makes that pop O(1), unlike `Vec::remove(0)`.
        scripts: Mutex<VecDeque<Script>>,
    }

    impl MockAgent {
        /// Create a mock with the given id, model name and scripts (played
        /// in order). The timeout is fixed at 30 000 ms.
        pub fn new(id: &str, model: &str, scripts: Vec<Script>) -> Self {
            Self {
                id: id.into(),
                model: model.into(),
                timeout_ms: 30_000,
                scripts: Mutex::new(VecDeque::from(scripts)),
            }
        }

        /// Number of scripts remaining (useful for assertions).
        pub fn remaining_scripts(&self) -> usize {
            self.scripts
                .lock()
                .expect("MockAgent script queue mutex poisoned")
                .len()
        }
    }

    impl Agent for MockAgent {
        fn id(&self) -> &ExpertId {
            &self.id
        }

        fn model(&self) -> &str {
            &self.model
        }

        fn timeout_ms(&self) -> u64 {
            self.timeout_ms
        }

        fn chat(&self, _request: ChatRequest) -> ChatTokenStream {
            // The guard is a temporary dropped at the end of this statement,
            // so the lock is never held while the stream is built or polled.
            let next = self
                .scripts
                .lock()
                .expect("MockAgent script queue mutex poisoned")
                .pop_front();

            match next {
                None => Box::pin(stream::empty()),
                Some(Script::Tokens(toks)) => {
                    // For an empty script this is 0, but the iterator below
                    // yields nothing, so the value is never compared.
                    let last_idx = toks.len().saturating_sub(1);
                    let items = toks.into_iter().enumerate().map(move |(i, text)| {
                        let finished = i == last_idx;
                        Ok::<_, AgentError>(ChatToken {
                            text: text.to_string(),
                            finished,
                            finish_reason: finished.then(|| "stop".to_string()),
                        })
                    });
                    Box::pin(stream::iter(items))
                }
                Some(Script::Error(msg)) => {
                    let err = AgentError::Stream {
                        agent_id: self.id.clone(),
                        message: msg.into(),
                    };
                    Box::pin(stream::iter(std::iter::once(
                        Err::<ChatToken, AgentError>(err),
                    )))
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::testing::{MockAgent, Script};
    use super::*;
    use futures::StreamExt;

    /// Minimal one-message request; `MockAgent` ignores its contents.
    fn req() -> ChatRequest {
        ChatRequest {
            model: "test-model".into(),
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "hi".into(),
            }],
            sampling: SamplingConfig::default(),
            request_id: None,
        }
    }

    /// Drain `stream`, panicking on the first `Err` item.
    async fn collect_ok(stream: ChatTokenStream) -> Vec<ChatToken> {
        stream.map(|item| item.expect("token ok")).collect().await
    }

    #[tokio::test]
    async fn mock_agent_yields_scripted_tokens_in_order() {
        let agent = MockAgent::new(
            "A",
            "test-model",
            vec![Script::Tokens(vec!["hello", " ", "world"])],
        );
        let tokens = collect_ok(agent.chat(req())).await;

        let texts: Vec<&str> = tokens.iter().map(|t| t.text.as_str()).collect();
        assert_eq!(texts, ["hello", " ", "world"]);

        // Only the final chunk may close the reply, and it carries "stop".
        let (last, rest) = tokens.split_last().expect("at least one token");
        assert!(last.finished, "final token must have finished=true");
        assert_eq!(last.finish_reason.as_deref(), Some("stop"));
        for tok in rest {
            assert!(!tok.finished, "non-final token {:?} must not be finished", tok.text);
            assert_eq!(tok.finish_reason, None);
        }
    }

    #[tokio::test]
    async fn mock_agent_yields_error_when_scripted() {
        let agent = MockAgent::new("B", "test-model", vec![Script::Error("kaboom")]);
        let mut stream = agent.chat(req());
        let item = stream.next().await.expect("one item");
        match item {
            Err(AgentError::Stream { ref agent_id, ref message }) => {
                assert_eq!(agent_id, "B");
                assert!(message.contains("kaboom"));
            }
            other => panic!("expected Stream error, got {other:?}"),
        }
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn mock_agent_empty_token_script_yields_empty_stream() {
        let agent = MockAgent::new("A", "m", vec![Script::Tokens(vec![])]);
        let items: Vec<_> = agent.chat(req()).collect().await;
        assert!(items.is_empty());
        // The empty script is still consumed by the call.
        assert_eq!(agent.remaining_scripts(), 0);
    }

    #[tokio::test]
    async fn mock_agent_consumes_one_script_per_call() {
        let agent = MockAgent::new(
            "A",
            "m",
            vec![
                Script::Tokens(vec!["round0"]),
                Script::Tokens(vec!["round1"]),
            ],
        );
        assert_eq!(agent.remaining_scripts(), 2);

        let first = collect_ok(agent.chat(req())).await;
        assert_eq!(first[0].text, "round0");
        assert_eq!(agent.remaining_scripts(), 1);

        let second = collect_ok(agent.chat(req())).await;
        assert_eq!(second[0].text, "round1");
        assert_eq!(agent.remaining_scripts(), 0);

        // Third call yields empty stream.
        let third: Vec<_> = agent.chat(req()).collect().await;
        assert!(third.is_empty());
    }

    // Purely synchronous: nothing here awaits, so no async runtime is needed.
    #[test]
    fn agent_metadata_is_accessible() {
        let agent = MockAgent::new("X", "qwen3", vec![]);
        assert_eq!(agent.id(), "X");
        assert_eq!(agent.model(), "qwen3");
        assert_eq!(agent.timeout_ms(), 30_000);
    }
}