steer_grpc/
local_server.rs

1use crate::grpc::error::GrpcError;
2use crate::grpc::server::AgentServiceImpl;
3type Result<T> = std::result::Result<T, GrpcError>;
4use std::sync::Arc;
5use steer_core::session::{SessionManager, SessionManagerConfig};
6use steer_proto::agent::v1::agent_service_server::AgentServiceServer;
7use tokio::sync::oneshot;
8use tonic::transport::{Channel, Server};
9
10/// Creates a localhost gRPC server and client channel
11/// This runs both server and client in the same process using localhost TCP
12pub async fn create_local_channel(
13    session_manager: Arc<SessionManager>,
14) -> Result<(Channel, tokio::task::JoinHandle<()>)> {
15    // Create a channel for the server's bound address
16    let (tx, rx) = oneshot::channel();
17
18    // Create LlmConfigProvider
19    #[cfg(not(test))]
20    let auth_storage = std::sync::Arc::new(
21        steer_core::auth::DefaultAuthStorage::new().map_err(|e| GrpcError::CoreError(e.into()))?,
22    );
23
24    #[cfg(test)]
25    let auth_storage = std::sync::Arc::new(steer_core::test_utils::InMemoryAuthStorage::new());
26
27    let llm_config_provider = steer_core::config::LlmConfigProvider::new(auth_storage);
28
29    // Create the service
30    let service = AgentServiceImpl::new(session_manager, llm_config_provider);
31    let svc = AgentServiceServer::new(service);
32
33    // Spawn the server with a listener on localhost
34    let server_handle: tokio::task::JoinHandle<()> = tokio::spawn(async move {
35        // Bind to port 0 to get a random available port
36        let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap();
37        let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
38        let local_addr = listener.local_addr().unwrap();
39
40        // Send the bound address back
41        tx.send(local_addr).unwrap();
42
43        // Run the server
44        Server::builder()
45            .add_service(svc)
46            .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
47            .await
48            .expect("Failed to run localhost server");
49    });
50
51    // Wait for the server to be ready and get its address
52    let addr = rx
53        .await
54        .map_err(|e| GrpcError::ChannelError(format!("Failed to receive server address: {e}")))?;
55
56    // Use tonic::transport::Endpoint for proper URI parsing
57    let endpoint =
58        tonic::transport::Endpoint::try_from(format!("http://{addr}"))?.tcp_nodelay(true);
59    let channel = endpoint.connect().await?;
60
61    Ok((channel, server_handle))
62}
63
64/// Creates a complete localhost gRPC setup for local mode
65/// Returns the channel and a handle to the server task
66pub async fn setup_local_grpc(
67    default_model: steer_core::api::Model,
68    session_db_path: Option<std::path::PathBuf>,
69) -> Result<(Channel, tokio::task::JoinHandle<()>)> {
70    // Create session store with the provided configuration
71    let store_config = steer_core::utils::session::resolve_session_store_config(session_db_path)?;
72    let session_store =
73        steer_core::utils::session::create_session_store_with_config(store_config).await?;
74
75    // Create global event channel (not used in local mode but required)
76    let session_manager_config = SessionManagerConfig {
77        max_concurrent_sessions: 10,
78        default_model,
79        auto_persist: true,
80    };
81
82    let session_manager = Arc::new(SessionManager::new(session_store, session_manager_config));
83
84    // Create localhost channel
85    let (channel, handle) = create_local_channel(session_manager).await?;
86
87    Ok((channel, handle))
88}