steer_grpc/
local_server.rs1use crate::grpc::error::GrpcError;
2use crate::grpc::server::AgentServiceImpl;
3type Result<T> = std::result::Result<T, GrpcError>;
4use std::sync::Arc;
5use steer_core::session::{SessionManager, SessionManagerConfig};
6use steer_proto::agent::v1::agent_service_server::AgentServiceServer;
7use tokio::sync::oneshot;
8use tonic::transport::{Channel, Server};
9
10pub async fn create_local_channel(
13 session_manager: Arc<SessionManager>,
14) -> Result<(Channel, tokio::task::JoinHandle<()>)> {
15 let (tx, rx) = oneshot::channel();
17
18 #[cfg(not(test))]
20 let auth_storage = std::sync::Arc::new(
21 steer_core::auth::DefaultAuthStorage::new().map_err(|e| GrpcError::CoreError(e.into()))?,
22 );
23
24 #[cfg(test)]
25 let auth_storage = std::sync::Arc::new(steer_core::test_utils::InMemoryAuthStorage::new());
26
27 let llm_config_provider = steer_core::config::LlmConfigProvider::new(auth_storage);
28
29 let service = AgentServiceImpl::new(session_manager, llm_config_provider);
31 let svc = AgentServiceServer::new(service);
32
33 let server_handle: tokio::task::JoinHandle<()> = tokio::spawn(async move {
35 let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
37 let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
38 let local_addr = listener.local_addr().unwrap();
39
40 tx.send(local_addr).unwrap();
42
43 Server::builder()
45 .add_service(svc)
46 .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
47 .await
48 .expect("Failed to run localhost server");
49 });
50
51 let addr = rx
53 .await
54 .map_err(|e| GrpcError::ChannelError(format!("Failed to receive server address: {e}")))?;
55
56 let endpoint =
58 tonic::transport::Endpoint::try_from(format!("http://{addr}"))?.tcp_nodelay(true);
59 let channel = endpoint.connect().await?;
60
61 Ok((channel, server_handle))
62}
63
64pub async fn setup_local_grpc(
67 default_model: steer_core::api::Model,
68 session_db_path: Option<std::path::PathBuf>,
69) -> Result<(Channel, tokio::task::JoinHandle<()>)> {
70 let store_config = steer_core::utils::session::resolve_session_store_config(session_db_path)?;
72 let session_store =
73 steer_core::utils::session::create_session_store_with_config(store_config).await?;
74
75 let session_manager_config = SessionManagerConfig {
77 max_concurrent_sessions: 10,
78 default_model,
79 auto_persist: true,
80 };
81
82 let session_manager = Arc::new(SessionManager::new(session_store, session_manager_config));
83
84 let (channel, handle) = create_local_channel(session_manager).await?;
86
87 Ok((channel, handle))
88}