1use crate::grpc::error::GrpcError;
2type Result<T> = std::result::Result<T, GrpcError>;
3use std::net::SocketAddr;
4use std::sync::Arc;
5use tokio::sync::oneshot;
6use tokio::task::JoinHandle;
7use tonic::transport::Server;
8use tracing::{error, info};
9
10use crate::grpc::server::AgentServiceImpl;
11use steer_core::auth::storage::AuthStorage;
12use steer_core::catalog::CatalogConfig;
13use steer_core::session::{SessionManager, SessionManagerConfig, SessionStore};
14use steer_proto::agent::v1::agent_service_server::AgentServiceServer;
15
16#[derive(Clone)]
18pub struct ServiceHostConfig {
19 pub db_path: std::path::PathBuf,
21 pub session_manager_config: SessionManagerConfig,
23 pub bind_addr: SocketAddr,
25 pub auth_storage: Arc<dyn AuthStorage>,
27 pub catalog_config: CatalogConfig,
29}
30
31impl std::fmt::Debug for ServiceHostConfig {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("ServiceHostConfig")
34 .field("db_path", &self.db_path)
35 .field("session_manager_config", &self.session_manager_config)
36 .field("bind_addr", &self.bind_addr)
37 .field("auth_storage", &"Arc<dyn AuthStorage>")
38 .field("catalog_config", &self.catalog_config)
39 .finish()
40 }
41}
42
43impl ServiceHostConfig {
44 pub fn new(
46 db_path: std::path::PathBuf,
47 session_manager_config: SessionManagerConfig,
48 bind_addr: SocketAddr,
49 ) -> Result<Self> {
50 let auth_storage = Arc::new(
51 steer_core::auth::DefaultAuthStorage::new()
52 .map_err(|e| GrpcError::CoreError(e.into()))?,
53 );
54
55 Ok(Self {
56 db_path,
57 session_manager_config,
58 bind_addr,
59 auth_storage,
60 catalog_config: CatalogConfig::default(),
61 })
62 }
63
64 pub fn with_catalog(
66 db_path: std::path::PathBuf,
67 session_manager_config: SessionManagerConfig,
68 bind_addr: SocketAddr,
69 catalog_config: CatalogConfig,
70 ) -> Result<Self> {
71 let auth_storage = Arc::new(
72 steer_core::auth::DefaultAuthStorage::new()
73 .map_err(|e| GrpcError::CoreError(e.into()))?,
74 );
75
76 Ok(Self {
77 db_path,
78 session_manager_config,
79 bind_addr,
80 auth_storage,
81 catalog_config,
82 })
83 }
84}
85
86pub struct ServiceHost {
89 session_manager: Arc<SessionManager>,
90 model_registry: Arc<steer_core::model_registry::ModelRegistry>,
91 provider_registry: Arc<steer_core::auth::ProviderRegistry>,
92 server_handle: Option<JoinHandle<Result<()>>>,
93 cleanup_handle: Option<JoinHandle<()>>,
94 shutdown_tx: Option<oneshot::Sender<()>>,
95 config: ServiceHostConfig,
96}
97
98impl ServiceHost {
99 pub async fn new(config: ServiceHostConfig) -> Result<Self> {
101 let store = create_session_store(&config.db_path).await?;
103
104 let model_registry = Arc::new(
106 steer_core::model_registry::ModelRegistry::load(&config.catalog_config.catalog_paths)
107 .map_err(|e| GrpcError::InvalidSessionState {
108 reason: format!("Failed to load model registry: {e}"),
109 })?,
110 );
111
112 let provider_registry = Arc::new(
114 steer_core::auth::ProviderRegistry::load(&config.catalog_config.catalog_paths)
115 .map_err(|e| GrpcError::InvalidSessionState {
116 reason: format!("Failed to load provider registry: {e}"),
117 })?,
118 );
119
120 let session_manager = Arc::new(SessionManager::new(
122 store,
123 config.session_manager_config.clone(),
124 ));
125
126 info!(
127 "ServiceHost initialized with database at {:?}",
128 config.db_path
129 );
130
131 Ok(Self {
132 session_manager,
133 model_registry,
134 provider_registry,
135 server_handle: None,
136 cleanup_handle: None,
137 shutdown_tx: None,
138 config,
139 })
140 }
141
142 pub async fn start(&mut self) -> Result<()> {
144 if self.server_handle.is_some() {
145 return Err(GrpcError::InvalidSessionState {
146 reason: "Server is already running".to_string(),
147 });
148 }
149
150 let llm_config_provider =
152 steer_core::config::LlmConfigProvider::new(self.config.auth_storage.clone());
153
154 let service = AgentServiceImpl::new(
155 self.session_manager.clone(),
156 llm_config_provider,
157 self.model_registry.clone(),
158 self.provider_registry.clone(),
159 );
160 let (shutdown_tx, shutdown_rx) = oneshot::channel();
161
162 let addr = self.config.bind_addr;
163
164 info!("Starting gRPC server on {}", addr);
165
166 let server_handle = tokio::spawn(async move {
167 Server::builder()
168 .add_service(AgentServiceServer::new(service))
169 .serve_with_shutdown(addr, async {
170 shutdown_rx.await.ok();
171 info!("gRPC server shutdown signal received");
172 })
173 .await
174 .map_err(GrpcError::ConnectionFailed)
175 });
176
177 let session_manager = self.session_manager.clone();
179 let cleanup_handle = tokio::spawn(async move {
180 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(300)); loop {
182 interval.tick().await;
183
184 let idle_duration = chrono::Duration::minutes(30);
186 match session_manager
187 .cleanup_inactive_sessions(idle_duration)
188 .await
189 {
190 0 => {} count => info!("Cleaned up {} inactive sessions", count),
192 }
193 }
194 });
195
196 self.server_handle = Some(server_handle);
197 self.cleanup_handle = Some(cleanup_handle);
198 self.shutdown_tx = Some(shutdown_tx);
199
200 info!("gRPC server listening on {}", addr);
201 Ok(())
202 }
203
204 pub async fn shutdown(mut self) -> Result<()> {
206 info!("Initiating ServiceHost shutdown");
207
208 if let Some(shutdown_tx) = self.shutdown_tx.take() {
210 let _ = shutdown_tx.send(());
211 }
212
213 if let Some(cleanup_handle) = self.cleanup_handle.take() {
215 cleanup_handle.abort();
216 }
217
218 if let Some(server_handle) = self.server_handle.take() {
220 match server_handle.await {
221 Ok(Ok(())) => info!("gRPC server shut down successfully"),
222 Ok(Err(e)) => error!("gRPC server error during shutdown: {}", e),
223 Err(e) => error!("Failed to join server task: {}", e),
224 }
225 }
226
227 let active_sessions = self.session_manager.get_active_sessions().await;
229 for session_id in active_sessions {
230 if let Err(e) = self.session_manager.suspend_session(&session_id).await {
231 error!(
232 "Failed to suspend session {} during shutdown: {}",
233 session_id, e
234 );
235 }
236 }
237
238 info!("ServiceHost shutdown complete");
239 Ok(())
240 }
241
242 pub fn session_manager(&self) -> &Arc<SessionManager> {
244 &self.session_manager
245 }
246
247 pub async fn wait(&mut self) -> Result<()> {
249 if let Some(server_handle) = &mut self.server_handle {
250 match server_handle.await {
251 Ok(result) => result,
252 Err(e) => Err(GrpcError::StreamError(format!("Server task panicked: {e}"))),
253 }
254 } else {
255 Err(GrpcError::InvalidSessionState {
256 reason: "Server is not running".to_string(),
257 })
258 }
259 }
260}
261
262async fn create_session_store(db_path: &std::path::Path) -> Result<Arc<dyn SessionStore>> {
264 use steer_core::session::SessionStoreConfig;
265 use steer_core::utils::session::create_session_store_with_config;
266
267 let config = SessionStoreConfig::sqlite(db_path.to_path_buf());
268 create_session_store_with_config(config)
269 .await
270 .map_err(|e| GrpcError::InvalidSessionState {
271 reason: format!("Failed to create session store: {e}"),
272 })
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    use tempfile::TempDir;

    /// Builds a host configuration backed by a fresh temporary directory and in-memory auth
    /// storage.
    ///
    /// The `TempDir` is returned alongside the config so the caller keeps the directory alive
    /// for the duration of the test.
    fn create_test_config() -> (ServiceHostConfig, TempDir) {
        let temp_dir = TempDir::new().unwrap();

        let session_manager_config = SessionManagerConfig {
            max_concurrent_sessions: 10,
            default_model: steer_core::config::model::builtin::claude_3_7_sonnet_20250219(),
            auto_persist: true,
        };

        let config = ServiceHostConfig {
            db_path: temp_dir.path().join("test.db"),
            session_manager_config,
            // Port 0 avoids hard-coding a port number in the test configuration.
            bind_addr: "127.0.0.1:0".parse().unwrap(),
            auth_storage: Arc::new(steer_core::test_utils::InMemoryAuthStorage::new()),
            catalog_config: CatalogConfig::default(),
        };

        (config, temp_dir)
    }

    /// A freshly created host reports no active sessions.
    #[tokio::test]
    async fn test_service_host_creation() {
        let (config, _temp_dir) = create_test_config();

        let host = ServiceHost::new(config).await.unwrap();

        let active = host.session_manager().get_active_sessions().await;
        assert_eq!(active.len(), 0);
    }

    /// Starting the host registers a server task, and shutting it down succeeds.
    #[tokio::test]
    async fn test_service_host_lifecycle() {
        let (mut config, _temp_dir) = create_test_config();
        config.bind_addr = "127.0.0.1:0".parse().unwrap();

        let mut host = ServiceHost::new(config).await.unwrap();

        host.start().await.unwrap();

        assert!(host.server_handle.is_some());

        host.shutdown().await.unwrap();
    }
}