llama-rs 0.17.0

A high-performance Rust implementation of llama.cpp - LLM inference engine with full GGUF support
Documentation
//! gRPC agent — orchestrator-side client wrapping the generated tonic
//! `Council` client.

use futures::stream::StreamExt;

use crate::council::agent::{Agent, AgentError, ChatRequest, ChatToken, ChatTokenStream};
use crate::council::event::ExpertId;
use crate::council::proto::council_client::CouncilClient;
use crate::council::proto::{
    ChatMessage as ProtoChatMessage, ChatRequest as ProtoChatRequest,
    SamplingParams as ProtoSampling,
};

/// Client to a remote llama-rs council agent over gRPC.
///
/// Built with [`GrpcAgent::connect`]; exposes the remote agent through the
/// [`Agent`] trait.
#[derive(Debug)]
pub struct GrpcAgent {
    /// Identifier of this agent; returned by `Agent::id` and attached as
    /// `agent_id` to the `AgentError`s this client produces.
    id: ExpertId,
    /// Model name supplied to `connect`; returned by `Agent::model`.
    model: String,
    /// Timeout in milliseconds supplied to `connect`; returned by
    /// `Agent::timeout_ms`.
    /// NOTE(review): this module only stores and reports the value —
    /// presumably the caller is the one enforcing it; confirm.
    timeout_ms: u64,
    /// `tonic::transport::Channel` is cheaply clonable (it's a handle
    /// to a connection pool); we clone it per call instead of locking.
    client: CouncilClient<tonic::transport::Channel>,
}

impl GrpcAgent {
    /// Open a channel to a council agent and wrap it in a [`GrpcAgent`].
    ///
    /// A leading `grpc://` in `endpoint` is rewritten to `http://` before the
    /// string is handed to tonic's URI parser; an endpoint without that
    /// prefix is passed through unchanged. `id`, `model` and `timeout_ms`
    /// are stored on the returned agent as given.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Transport`] tagged with `id` when tonic rejects
    /// the endpoint string (`invalid endpoint: …`) or when the connection
    /// attempt fails (`connect failed: …`).
    pub async fn connect(
        id: impl Into<ExpertId>,
        model: impl Into<String>,
        endpoint: &str,
        timeout_ms: u64,
    ) -> Result<Self, AgentError> {
        let agent_id: ExpertId = id.into();

        // Swap the `grpc://` scheme for `http://` for tonic's URI parser.
        let uri = match endpoint.strip_prefix("grpc://") {
            Some(authority) => format!("http://{authority}"),
            None => endpoint.to_string(),
        };

        // Step 1: validate the URI and obtain an endpoint builder.
        let target = match tonic::transport::Channel::from_shared(uri) {
            Ok(target) => target,
            Err(e) => {
                return Err(AgentError::Transport {
                    agent_id,
                    message: format!("invalid endpoint: {e}"),
                });
            }
        };

        // Step 2: actually establish the connection.
        let channel = match target.connect().await {
            Ok(channel) => channel,
            Err(e) => {
                return Err(AgentError::Transport {
                    agent_id,
                    message: format!("connect failed: {e}"),
                });
            }
        };

        let client = CouncilClient::new(channel);
        Ok(Self {
            id: agent_id,
            model: model.into(),
            timeout_ms,
            client,
        })
    }
}

impl Agent for GrpcAgent {
    /// Identifier this agent was connected with.
    fn id(&self) -> &ExpertId {
        &self.id
    }

    /// Model name this agent was connected with.
    fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Timeout, in milliseconds, this agent was connected with.
    fn timeout_ms(&self) -> u64 {
        self.timeout_ms
    }

    /// Issue the `chat` RPC and return its tokens as a [`ChatTokenStream`].
    ///
    /// If the RPC itself fails, the returned stream yields a single
    /// [`AgentError::Transport`] (`chat rpc failed: …`) instead of tokens.
    fn chat(&self, request: ChatRequest) -> ChatTokenStream {
        // Take owned copies of everything the async block needs so it can be
        // `move`d without borrowing `self`.
        let mut rpc = self.client.clone();
        let agent_id = self.id.clone();
        let wire_request = build_proto_request(&request);

        // A one-item stream whose single item is itself a stream: awaiting
        // the RPC yields either the translated response stream or a
        // one-error stream. Flattening it surfaces the inner items lazily as
        // the caller polls.
        let opened = futures::stream::once(async move {
            let outcome = rpc.chat(wire_request).await;
            match outcome {
                Err(status) => single_err(agent_id, format!("chat rpc failed: {status}")),
                Ok(response) => translate_stream(agent_id, response.into_inner()),
            }
        });

        Box::pin(opened.flatten())
    }
}

/// Adapt a tonic response stream into the agent-level [`ChatTokenStream`].
///
/// Each protobuf token is copied field-for-field into a [`ChatToken`]; each
/// `tonic::Status` error becomes an [`AgentError::Stream`] tagged with
/// `agent_id`.
fn translate_stream(
    agent_id: ExpertId,
    inner: tonic::Streaming<crate::council::proto::ChatToken>,
) -> ChatTokenStream {
    let converted = inner.map(move |item| {
        item.map(|wire| ChatToken {
            text: wire.text,
            finished: wire.finished,
            finish_reason: wire.finish_reason,
        })
        .map_err(|status| AgentError::Stream {
            // Cloned per error because the closure may be called repeatedly.
            agent_id: agent_id.clone(),
            message: format!("stream item error: {status}"),
        })
    });

    Box::pin(converted)
}

fn single_err(agent_id: ExpertId, message: String) -> ChatTokenStream {
    Box::pin(futures::stream::iter(vec![Err(AgentError::Transport {
        agent_id,
        message,
    })]))
}

/// Convert an agent-level [`ChatRequest`] into its protobuf wire form.
///
/// Messages are copied in order with the role rendered through
/// `as_wire_str`; the sampling block is always present (`Some`) and mirrors
/// `req.sampling` field-for-field. `model` and `request_id` are cloned as-is.
fn build_proto_request(req: &ChatRequest) -> ProtoChatRequest {
    let mut messages = Vec::with_capacity(req.messages.len());
    for msg in req.messages.iter() {
        messages.push(ProtoChatMessage {
            role: msg.role.as_wire_str().to_string(),
            content: msg.content.clone(),
        });
    }

    let sampling = ProtoSampling {
        temperature: req.sampling.temperature,
        top_p: req.sampling.top_p,
        max_tokens: req.sampling.max_tokens,
        seed: req.sampling.seed,
    };

    ProtoChatRequest {
        model: req.model.clone(),
        messages,
        sampling: Some(sampling),
        request_id: req.request_id.clone(),
    }
}

// Integration-style tests: spin up an in-process tonic server backed by a
// canned `Council` implementation and drive it through `GrpcAgent`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::council::agent::{ChatMessage, ChatRequest, ChatRole};
    use crate::council::config::SamplingConfig;
    use crate::council::proto::council_server::{Council as CouncilService, CouncilServer};
    use crate::council::proto::{
        ChatToken as ProtoChatToken, HelloRequest, HelloResponse,
    };
    use futures::stream::Stream;
    use std::pin::Pin;
    use std::time::Duration;
    use tokio::sync::oneshot;
    use tonic::{Request, Response, Status};

    /// Test-only Council impl returning a fixed token sequence.
    struct FixedAgent {
        /// Token texts streamed back, in order, by `chat`.
        tokens: Vec<&'static str>,
    }

    /// Boxed stream type used as the `ChatStream` associated type of the
    /// test service.
    type ServerChatStream =
        Pin<Box<dyn Stream<Item = Result<ProtoChatToken, Status>> + Send + 'static>>;

    #[tonic::async_trait]
    impl CouncilService for FixedAgent {
        /// Returns a fixed `HelloResponse`; the tests in this module do not
        /// call it.
        async fn hello(
            &self,
            _req: Request<HelloRequest>,
        ) -> Result<Response<HelloResponse>, Status> {
            Ok(Response::new(HelloResponse {
                server_version: "test".into(),
                model_name: "fixed-model".into(),
                context_window: 4096,
                supports_streaming: true,
            }))
        }

        type ChatStream = ServerChatStream;

        /// Ignores the request and streams `self.tokens` in order. Only the
        /// last token carries `finished == true` and
        /// `finish_reason == Some("stop")`.
        async fn chat(
            &self,
            _req: Request<ProtoChatRequest>,
        ) -> Result<Response<Self::ChatStream>, Status> {
            let toks: Vec<&'static str> = self.tokens.clone();
            // `saturating_sub` keeps this from underflowing when `toks` is
            // empty; in that case the iterator below yields nothing anyway.
            let last_idx = toks.len().saturating_sub(1);
            let items: Vec<Result<ProtoChatToken, Status>> = toks
                .iter()
                .enumerate()
                .map(|(i, t)| {
                    Ok(ProtoChatToken {
                        text: (*t).to_string(),
                        finished: i == last_idx,
                        finish_reason: if i == last_idx {
                            Some("stop".into())
                        } else {
                            None
                        },
                    })
                })
                .collect();
            let stream: ServerChatStream = Box::pin(futures::stream::iter(items));
            Ok(Response::new(stream))
        }
    }

    /// Start an in-process tonic server serving a `FixedAgent` with the
    /// given `tokens` on an ephemeral loopback port.
    ///
    /// Returns the bound address and a oneshot sender; sending on it signals
    /// the server to shut down.
    async fn spawn_test_server(
        tokens: Vec<&'static str>,
    ) -> (std::net::SocketAddr, oneshot::Sender<()>) {
        // Bind with tokio and hand the listener to tonic directly —
        // avoids the "drop std listener, hope tonic rebinds" race.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
        let svc = CouncilServer::new(FixedAgent { tokens });

        // Adapt the listener into a never-ending stream of accept results
        // (the closure always returns `Some`), dropping the peer address.
        let incoming = futures::stream::unfold(listener, |listener| async move {
            let res = listener
                .accept()
                .await
                .map(|(stream, _)| stream);
            Some((res, listener))
        });

        // Run the server in the background; its result is deliberately
        // ignored.
        tokio::spawn(async move {
            let _ = tonic::transport::Server::builder()
                .add_service(svc)
                .serve_with_incoming_shutdown(incoming, async move {
                    let _ = shutdown_rx.await;
                })
                .await;
        });

        // Tiny grace so the server has entered its accept loop.
        tokio::time::sleep(Duration::from_millis(30)).await;
        (addr, shutdown_tx)
    }

    /// Minimal request: one user message ("hi"), default sampling, no
    /// request id.
    fn req() -> ChatRequest {
        ChatRequest {
            model: "fixed-model".into(),
            messages: vec![ChatMessage {
                role: ChatRole::User,
                content: "hi".into(),
            }],
            sampling: SamplingConfig::default(),
            request_id: None,
        }
    }

    /// End-to-end: tokens served by the test server arrive through
    /// `GrpcAgent::chat` in order, and the final one is marked finished.
    #[tokio::test]
    async fn grpc_agent_streams_tokens_in_order() {
        let (addr, shutdown) = spawn_test_server(vec!["hello", " ", "world"]).await;
        let endpoint = format!("grpc://{addr}");
        let agent = GrpcAgent::connect("A", "fixed-model", &endpoint, 5_000)
            .await
            .expect("connect");

        let mut stream = agent.chat(req());
        let mut texts = Vec::new();
        let mut last_finished = false;
        while let Some(item) = stream.next().await {
            let tok = item.expect("ok");
            last_finished = tok.finished;
            texts.push(tok.text);
        }
        assert_eq!(texts, vec!["hello", " ", "world"]);
        assert!(last_finished);

        let _ = shutdown.send(());
    }

    /// `connect` to an address that is expected to refuse the connection
    /// must surface `AgentError::Transport` tagged with the agent id.
    /// NOTE(review): relies on nothing listening on 127.0.0.1:1.
    #[tokio::test]
    async fn grpc_agent_reports_transport_error_for_bad_endpoint() {
        let err = GrpcAgent::connect("A", "m", "grpc://127.0.0.1:1", 100)
            .await
            .unwrap_err();
        match err {
            AgentError::Transport { agent_id, .. } => assert_eq!(agent_id, "A"),
            other => panic!("expected Transport, got {other:?}"),
        }
    }
}