use futures::stream::StreamExt;
use crate::council::agent::{Agent, AgentError, ChatRequest, ChatToken, ChatTokenStream};
use crate::council::event::ExpertId;
use crate::council::proto::council_client::CouncilClient;
use crate::council::proto::{
ChatMessage as ProtoChatMessage, ChatRequest as ProtoChatRequest,
SamplingParams as ProtoSampling,
};
/// [`Agent`] implementation backed by a tonic gRPC channel to a remote council service.
#[derive(Debug)]
pub struct GrpcAgent {
/// Identifier returned by [`Agent::id`] and attached to the errors this agent produces.
id: ExpertId,
/// Model name returned by [`Agent::model`].
model: String,
/// Timeout in milliseconds returned by [`Agent::timeout_ms`]; nothing in this file
/// applies it to the channel or to individual calls.
timeout_ms: u64,
/// Generated gRPC client; `chat` clones it for every call.
client: CouncilClient<tonic::transport::Channel>,
}
impl GrpcAgent {
    /// Opens a gRPC channel to `endpoint` and wraps it in a [`GrpcAgent`].
    ///
    /// A leading `grpc://` scheme is rewritten to `http://` before the string is handed to
    /// tonic; any other endpoint string is passed through unchanged.
    ///
    /// `timeout_ms` is only stored on the agent (and later reported by
    /// [`Agent::timeout_ms`]); it is not applied to the connection attempt made here.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Transport`] carrying this agent's id when tonic rejects the
    /// endpoint string (`invalid endpoint: …`) or when connecting fails (`connect failed: …`).
    pub async fn connect(
        id: impl Into<ExpertId>,
        model: impl Into<String>,
        endpoint: &str,
        timeout_ms: u64,
    ) -> Result<Self, AgentError> {
        /// Builds the transport error reported for both failure modes of `connect`.
        fn transport_failure(agent_id: &ExpertId, message: String) -> AgentError {
            AgentError::Transport {
                agent_id: agent_id.clone(),
                message,
            }
        }

        let id: ExpertId = id.into();

        // tonic expects an http(s) URI, so translate the project's `grpc://` scheme.
        let uri = match endpoint.strip_prefix("grpc://") {
            Some(rest) => format!("http://{rest}"),
            None => endpoint.to_string(),
        };

        let target = tonic::transport::Channel::from_shared(uri)
            .map_err(|e| transport_failure(&id, format!("invalid endpoint: {e}")))?;
        let channel = target
            .connect()
            .await
            .map_err(|e| transport_failure(&id, format!("connect failed: {e}")))?;

        Ok(Self {
            id,
            model: model.into(),
            timeout_ms,
            client: CouncilClient::new(channel),
        })
    }
}
impl Agent for GrpcAgent {
    /// Identifier this agent was constructed with.
    fn id(&self) -> &ExpertId {
        &self.id
    }

    /// Model name this agent was constructed with.
    fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Timeout, in milliseconds, this agent was constructed with.
    fn timeout_ms(&self) -> u64 {
        self.timeout_ms
    }

    /// Starts a streaming chat RPC and returns its tokens as a [`ChatTokenStream`].
    ///
    /// The request is converted to its protobuf form immediately, but the RPC itself is
    /// issued only once the returned stream is polled. If the call cannot be opened, the
    /// stream yields a single [`AgentError::Transport`] item (`chat rpc failed: …`);
    /// otherwise every server item is forwarded through `translate_stream`.
    ///
    /// The stored `timeout_ms` is not enforced by this method.
    fn chat(&self, request: ChatRequest) -> ChatTokenStream {
        let agent_id = self.id.clone();
        let mut rpc = self.client.clone();
        let outbound = build_proto_request(&request);

        // Future that opens the call and resolves to the stream the caller will consume.
        let open = async move {
            let reply = rpc.chat(outbound).await;
            match reply {
                Err(status) => single_err(agent_id, format!("chat rpc failed: {status}")),
                Ok(resp) => translate_stream(agent_id, resp.into_inner()),
            }
        };

        Box::pin(futures::future::FutureExt::flatten_stream(open))
    }
}
/// Adapts a tonic response stream into the crate's [`ChatTokenStream`].
///
/// Each protobuf token is copied field-for-field into a [`ChatToken`]; each gRPC status
/// error becomes an [`AgentError::Stream`] tagged with `agent_id`.
fn translate_stream(
    agent_id: ExpertId,
    inner: tonic::Streaming<crate::council::proto::ChatToken>,
) -> ChatTokenStream {
    let converted = inner.map(move |item| {
        item.map(|wire| ChatToken {
            text: wire.text,
            finished: wire.finished,
            finish_reason: wire.finish_reason,
        })
        .map_err(|status| AgentError::Stream {
            agent_id: agent_id.clone(),
            message: format!("stream item error: {status}"),
        })
    });
    Box::pin(converted)
}
fn single_err(agent_id: ExpertId, message: String) -> ChatTokenStream {
Box::pin(futures::stream::iter(vec![Err(AgentError::Transport {
agent_id,
message,
})]))
}
/// Converts a crate-level [`ChatRequest`] into its protobuf counterpart.
///
/// Messages keep their order; each role is sent as the string returned by `as_wire_str`.
/// The sampling block is always populated (`Some`), copying all four sampling fields.
fn build_proto_request(req: &ChatRequest) -> ProtoChatRequest {
    let sampling = ProtoSampling {
        temperature: req.sampling.temperature,
        top_p: req.sampling.top_p,
        max_tokens: req.sampling.max_tokens,
        seed: req.sampling.seed,
    };

    let mut messages = Vec::with_capacity(req.messages.len());
    for msg in &req.messages {
        messages.push(ProtoChatMessage {
            role: msg.role.as_wire_str().to_string(),
            content: msg.content.clone(),
        });
    }

    ProtoChatRequest {
        model: req.model.clone(),
        messages,
        sampling: Some(sampling),
        request_id: req.request_id.clone(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::council::agent::{ChatMessage, ChatRequest, ChatRole};
use crate::council::config::SamplingConfig;
use crate::council::proto::council_server::{Council as CouncilService, CouncilServer};
use crate::council::proto::{
ChatToken as ProtoChatToken, HelloRequest, HelloResponse,
};
use futures::stream::Stream;
use std::pin::Pin;
use std::time::Duration;
use tokio::sync::oneshot;
use tonic::{Request, Response, Status};
/// Test double for the council service: its `chat` handler replies with `tokens`,
/// one stream item per entry.
struct FixedAgent {
/// Token texts to emit, in order.
tokens: Vec<&'static str>,
}
/// Boxed, pinned, `Send + 'static` stream of `Result<ProtoChatToken, Status>`, used as the
/// test server's `ChatStream` associated type.
type ServerChatStream =
Pin<Box<dyn Stream<Item = Result<ProtoChatToken, Status>> + Send + 'static>>;
#[tonic::async_trait]
impl CouncilService for FixedAgent {
    type ChatStream = ServerChatStream;

    /// Returns a fixed handshake response describing the fake model.
    async fn hello(
        &self,
        _req: Request<HelloRequest>,
    ) -> Result<Response<HelloResponse>, Status> {
        let reply = HelloResponse {
            server_version: "test".into(),
            model_name: "fixed-model".into(),
            context_window: 4096,
            supports_streaming: true,
        };
        Ok(Response::new(reply))
    }

    /// Ignores the request and streams the configured tokens in order.
    ///
    /// Only the final token has `finished == true` and a `finish_reason` of `"stop"`;
    /// an empty token list produces an empty stream.
    async fn chat(
        &self,
        _req: Request<ProtoChatRequest>,
    ) -> Result<Response<Self::ChatStream>, Status> {
        let total = self.tokens.len();
        let items: Vec<Result<ProtoChatToken, Status>> = self
            .tokens
            .iter()
            .enumerate()
            .map(|(idx, text)| {
                let is_last = idx + 1 == total;
                Ok(ProtoChatToken {
                    text: (*text).to_string(),
                    finished: is_last,
                    finish_reason: is_last.then(|| "stop".into()),
                })
            })
            .collect();

        let stream: ServerChatStream = Box::pin(futures::stream::iter(items));
        Ok(Response::new(stream))
    }
}
/// Starts a [`FixedAgent`]-backed gRPC server on an ephemeral localhost port.
///
/// Returns the bound address and a oneshot sender; sending on it (or dropping it)
/// completes the server's shutdown future. The server runs on a spawned tokio task whose
/// result is discarded.
async fn spawn_test_server(
    tokens: Vec<&'static str>,
) -> (std::net::SocketAddr, oneshot::Sender<()>) {
    // Port 0 lets the OS pick a free port; the real address is read back below.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    // Endless stream of accepted sockets (or accept errors) fed to tonic.
    let incoming = futures::stream::unfold(listener, |listener| async move {
        let accepted = match listener.accept().await {
            Ok((socket, _peer)) => Ok(socket),
            Err(err) => Err(err),
        };
        Some((accepted, listener))
    });

    let service = CouncilServer::new(FixedAgent { tokens });
    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

    tokio::spawn(async move {
        let stop_signal = async move {
            let _ = shutdown_rx.await;
        };
        let _ = tonic::transport::Server::builder()
            .add_service(service)
            .serve_with_incoming_shutdown(incoming, stop_signal)
            .await;
    });

    // Brief pause after spawning, presumably to let the server task start accepting.
    tokio::time::sleep(Duration::from_millis(30)).await;
    (addr, shutdown_tx)
}
/// Minimal chat request used by the tests: one user message saying "hi", default sampling,
/// no request id.
fn req() -> ChatRequest {
    let greeting = ChatMessage {
        role: ChatRole::User,
        content: "hi".into(),
    };
    ChatRequest {
        model: "fixed-model".into(),
        messages: vec![greeting],
        sampling: SamplingConfig::default(),
        request_id: None,
    }
}
#[tokio::test]
async fn grpc_agent_streams_tokens_in_order() {
let (addr, shutdown) = spawn_test_server(vec!["hello", " ", "world"]).await;
let endpoint = format!("grpc://{addr}");
let agent = GrpcAgent::connect("A", "fixed-model", &endpoint, 5_000)
.await
.expect("connect");
let mut stream = agent.chat(req());
let mut texts = Vec::new();
let mut last_finished = false;
while let Some(item) = stream.next().await {
let tok = item.expect("ok");
last_finished = tok.finished;
texts.push(tok.text);
}
assert_eq!(texts, vec!["hello", " ", "world"]);
assert!(last_finished);
let _ = shutdown.send(());
}
/// Connecting to 127.0.0.1:1 must fail with a `Transport` error that carries the agent id.
#[tokio::test]
async fn grpc_agent_reports_transport_error_for_bad_endpoint() {
    let failure = GrpcAgent::connect("A", "m", "grpc://127.0.0.1:1", 100)
        .await
        .unwrap_err();

    if let AgentError::Transport { agent_id, .. } = &failure {
        assert_eq!(*agent_id, "A");
    } else {
        panic!("expected Transport, got {failure:?}");
    }
}
}