1use crate::grpc::conversions::message_to_proto;
2use crate::grpc::session_manager_ext::SessionManagerExt;
3use std::sync::Arc;
4use steer_core::session::manager::SessionManager;
5use steer_proto::agent::v1::{self as proto, *};
6use tokio::sync::mpsc;
7use tokio_stream::wrappers::ReceiverStream;
8use tonic::{Request, Response, Status, Streaming};
9use tracing::{debug, error, info, warn};
10
11pub struct AgentServiceImpl {
12 session_manager: Arc<SessionManager>,
13 llm_config_provider: steer_core::config::LlmConfigProvider,
14}
15
16impl AgentServiceImpl {
17 pub fn new(
18 session_manager: Arc<SessionManager>,
19 llm_config_provider: steer_core::config::LlmConfigProvider,
20 ) -> Self {
21 Self {
22 session_manager,
23 llm_config_provider,
24 }
25 }
26}
27
28#[tonic::async_trait]
29impl agent_service_server::AgentService for AgentServiceImpl {
30 type StreamSessionStream = ReceiverStream<Result<StreamSessionResponse, Status>>;
31 type ListFilesStream = ReceiverStream<Result<ListFilesResponse, Status>>;
32 type GetSessionStream =
33 std::pin::Pin<Box<dyn futures::Stream<Item = Result<GetSessionResponse, Status>> + Send>>;
34 type GetConversationStream = std::pin::Pin<
35 Box<dyn futures::Stream<Item = Result<GetConversationResponse, Status>> + Send>,
36 >;
37 type ActivateSessionStream = std::pin::Pin<
38 Box<dyn futures::Stream<Item = Result<ActivateSessionResponse, Status>> + Send>,
39 >;
40
41 async fn stream_session(
42 &self,
43 request: Request<Streaming<StreamSessionRequest>>,
44 ) -> Result<Response<Self::StreamSessionStream>, Status> {
45 let mut client_stream = request.into_inner();
46 let (tx, rx) = mpsc::channel(100);
47
48 let session_manager = self.session_manager.clone();
50 let llm_config_provider = self.llm_config_provider.clone();
51
52 let _stream_task: tokio::task::JoinHandle<()> = tokio::spawn(async move {
53 let (session_id, mut event_rx) = if let Some(client_message_result) =
55 client_stream.message().await.transpose()
56 {
57 match client_message_result {
58 Ok(client_message) => {
59 let session_id = client_message.session_id.clone();
60
61 let receiver = match session_manager
63 .take_event_receiver(&client_message.session_id)
64 .await
65 {
66 Ok(receiver) => {
67 debug!("Session {} is already active, TUI should call GetConversation to retrieve history", session_id);
69 receiver
70 },
71 Err(steer_core::error::Error::SessionManager(steer_core::session::manager::SessionManagerError::SessionNotActive { session_id })) => {
72 info!("Session {} not active, attempting to resume", session_id);
73
74 match try_resume_session(&session_manager, &session_id, &llm_config_provider).await {
76 Ok(()) => {
77 match session_manager.take_event_receiver(&session_id).await {
79 Ok(receiver) => receiver,
80 Err(e) => {
81 error!("Failed to get event receiver after resuming session {}: {}", session_id, e);
82 let _ = tx
83 .send(Err(Status::internal(format!(
84 "Failed to establish stream after resuming session: {e}"
85 ))))
86 .await;
87 return;
88 }
89 }
90 }
91 Err(e) => {
92 error!("Failed to resume session {}: {}", session_id, e);
93 let _ = tx
94 .send(Err(e))
95 .await;
96 return;
97 }
98 }
99 }
100 Err(steer_core::error::Error::SessionManager(steer_core::session::manager::SessionManagerError::SessionAlreadyHasListener { session_id })) => {
101 error!("Session already has an active stream: {}", session_id);
102 let _ = tx
103 .send(Err(Status::already_exists(format!(
104 "Session {session_id} already has an active stream"
105 ))))
106 .await;
107 return;
108 }
109 Err(e) => {
110 error!("Error taking event receiver: {}", e);
111 let _ = tx
112 .send(Err(Status::internal(format!(
113 "Error establishing stream: {e}"
114 ))))
115 .await;
116 return;
117 }
118 };
119
120 if let Err(e) =
122 handle_client_message(&session_manager, client_message).await
123 {
124 error!("Error handling first client message: {}", e);
125 let _ = tx
126 .send(Err(Status::internal(format!(
127 "Error processing message: {e}"
128 ))))
129 .await;
130 return;
131 }
132
133 (session_id, receiver)
134 }
135 Err(e) => {
136 error!("Error receiving first client message: {}", e);
137 let _ = tx.send(Err(Status::internal("Stream error"))).await;
138 return;
139 }
140 }
141 } else {
142 error!("No initial client message received");
143 let _ = tx.send(Err(Status::internal("No initial message"))).await;
144 return;
145 };
146
147 let mut event_sequence = 0u64;
148
149 if let Err(e) = session_manager
151 .increment_subscriber_count(&session_id)
152 .await
153 {
154 warn!(
155 "Failed to increment subscriber count for session {}: {}",
156 session_id, e
157 );
158 }
159
160 let tx_clone = tx.clone();
162 let session_id_clone = session_id.clone();
163 let event_task = tokio::spawn(async move {
164 while let Some(app_event) = event_rx.recv().await {
165 event_sequence += 1;
166 let server_event = match crate::grpc::conversions::app_event_to_server_event(
167 app_event,
168 event_sequence,
169 ) {
170 Ok(event) => event,
171 Err(e) => {
172 warn!("Failed to convert app event to server event: {}", e);
173 continue;
174 }
175 };
176
177 if let Err(e) = tx_clone.send(Ok(server_event)).await {
178 warn!("Failed to send event to client: {}", e);
179 break;
180 }
181 }
182 debug!(
183 "Event forwarding task ended for session: {}",
184 session_id_clone
185 );
186 });
187
188 while let Some(client_message_result) = client_stream.message().await.transpose() {
190 match client_message_result {
191 Ok(client_message) => {
192 if let Err(e) = session_manager.touch_session(&session_id).await {
194 warn!("Failed to touch session {}: {}", session_id, e);
195 }
196
197 if let Err(e) =
198 handle_client_message(&session_manager, client_message).await
199 {
200 error!("Error handling client message: {}", e);
201 let _ = tx
202 .send(Err(Status::internal(format!(
203 "Error processing message: {e}"
204 ))))
205 .await;
206 break;
207 }
208 }
209 Err(e) => {
210 error!("Error receiving client message: {}", e);
211 let _ = tx.send(Err(Status::internal("Stream error"))).await;
212 break;
213 }
214 }
215 }
216
217 event_task.abort();
219
220 if let Err(e) = session_manager
222 .decrement_subscriber_count(&session_id)
223 .await
224 {
225 warn!(
226 "Failed to decrement subscriber count for session {}: {}",
227 session_id, e
228 );
229 }
230
231 info!("Client stream ended for session: {}", session_id);
232
233 if let Err(e) = session_manager
235 .maybe_suspend_idle_session(&session_id)
236 .await
237 {
238 warn!("Failed to check/suspend idle session {}: {}", session_id, e);
239 }
240 });
241
242 Ok(Response::new(ReceiverStream::new(rx)))
243 }
244
245 async fn create_session(
246 &self,
247 request: Request<CreateSessionRequest>,
248 ) -> Result<Response<CreateSessionResponse>, Status> {
249 let req = request.into_inner();
250
251 let app_config = steer_core::app::AppConfig {
252 llm_config_provider: self.llm_config_provider.clone(),
253 };
254
255 match self
256 .session_manager
257 .create_session_grpc(req, app_config)
258 .await
259 {
260 Ok((_session_id, session_info)) => Ok(Response::new(CreateSessionResponse {
261 session: Some(session_info),
262 })),
263 Err(e) => {
264 error!("Failed to create session: {}", e);
265 Err(e.into())
266 }
267 }
268 }
269
270 async fn list_sessions(
271 &self,
272 request: Request<ListSessionsRequest>,
273 ) -> Result<Response<ListSessionsResponse>, Status> {
274 let _req = request.into_inner();
275
276 let filter = steer_core::session::SessionFilter::default();
278
279 match self.session_manager.list_sessions(filter).await {
280 Ok(sessions) => {
281 let proto_sessions = sessions
282 .into_iter()
283 .map(|session| SessionInfo {
284 id: session.id,
285 created_at: Some(prost_types::Timestamp::from(
286 std::time::SystemTime::from(session.created_at),
287 )),
288 updated_at: Some(prost_types::Timestamp::from(
289 std::time::SystemTime::from(session.updated_at),
290 )),
291 status: proto::SessionStatus::Active as i32,
292 metadata: Some(proto::SessionMetadata {
293 labels: session.metadata,
294 annotations: std::collections::HashMap::new(),
295 }),
296 })
297 .collect();
298
299 Ok(Response::new(ListSessionsResponse {
300 sessions: proto_sessions,
301 next_page_token: None,
302 }))
303 }
304 Err(e) => {
305 error!("Failed to list sessions: {}", e);
306 Err(Status::internal(format!("Failed to list sessions: {e}")))
307 }
308 }
309 }
310
311 async fn get_session(
312 &self,
313 request: Request<GetSessionRequest>,
314 ) -> Result<Response<Self::GetSessionStream>, Status> {
315 let req = request.into_inner();
316 let session_manager = self.session_manager.clone();
317
318 let stream = async_stream::try_stream! {
319 match session_manager.get_session_proto(&req.session_id).await {
320 Ok(Some(session_state)) => {
321 yield GetSessionResponse {
323 chunk: Some(get_session_response::Chunk::Header(SessionStateHeader {
324 id: session_state.id,
325 created_at: session_state.created_at,
326 updated_at: session_state.updated_at,
327 config: session_state.config,
328 })),
329 };
330
331 for message in session_state.messages {
333 yield GetSessionResponse {
334 chunk: Some(get_session_response::Chunk::Message(message)),
335 };
336 }
337
338 for (key, value) in session_state.tool_calls {
340 yield GetSessionResponse {
341 chunk: Some(get_session_response::Chunk::ToolCall(ToolCallStateEntry {
342 key,
343 value: Some(value),
344 })),
345 };
346 }
347
348 yield GetSessionResponse {
350 chunk: Some(get_session_response::Chunk::Footer(SessionStateFooter {
351 approved_tools: session_state.approved_tools,
352 last_event_sequence: session_state.last_event_sequence,
353 metadata: session_state.metadata,
354 })),
355 };
356 }
357 Ok(None) => {
358 Err(Status::not_found(format!(
359 "Session not found: {}",
360 req.session_id
361 )))?;
362 }
363 Err(e) => {
364 error!("Failed to get session: {}", e);
365 Err(Status::internal(format!("Failed to get session: {e}")))?;
366 }
367 }
368 };
369
370 Ok(Response::new(Box::pin(stream)))
371 }
372
373 async fn delete_session(
374 &self,
375 request: Request<DeleteSessionRequest>,
376 ) -> Result<Response<DeleteSessionResponse>, Status> {
377 let req = request.into_inner();
378
379 match self.session_manager.delete_session(&req.session_id).await {
380 Ok(true) => Ok(Response::new(DeleteSessionResponse {})),
381 Ok(false) => Err(Status::not_found(format!(
382 "Session not found: {}",
383 req.session_id
384 ))),
385 Err(e) => {
386 error!("Failed to delete session: {}", e);
387 Err(Status::internal(format!("Failed to delete session: {e}")))
388 }
389 }
390 }
391
392 async fn get_conversation(
393 &self,
394 request: Request<GetConversationRequest>,
395 ) -> Result<Response<Self::GetConversationStream>, Status> {
396 let req = request.into_inner();
397 let session_manager = self.session_manager.clone();
398
399 info!("GetConversation called for session: {}", req.session_id);
400
401 let stream = async_stream::try_stream! {
402 match session_manager.get_session_state(&req.session_id).await {
403 Ok(Some(session_state)) => {
404 info!(
405 "Found session state with {} messages and {} approved tools",
406 session_state.messages.len(),
407 session_state.approved_tools.len()
408 );
409
410 for msg in session_state.messages {
412 let proto_msg = message_to_proto(msg.clone())
413 .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
414 yield GetConversationResponse {
415 chunk: Some(get_conversation_response::Chunk::Message(proto_msg)),
416 };
417 }
418
419 yield GetConversationResponse {
421 chunk: Some(get_conversation_response::Chunk::Footer(GetConversationFooter {
422 approved_tools: session_state.approved_tools.into_iter().collect(),
423 })),
424 };
425 }
426 Ok(None) => {
427 Err(Status::not_found(format!(
428 "Session not found: {}",
429 req.session_id
430 )))?;
431 }
432 Err(e) => {
433 error!("Failed to get session state: {}", e);
434 Err(Status::internal(format!("Failed to get session state: {e}")))?;
435 }
436 }
437 };
438
439 Ok(Response::new(Box::pin(stream)))
440 }
441
442 async fn send_message(
443 &self,
444 request: Request<SendMessageRequest>,
445 ) -> Result<Response<SendMessageResponse>, Status> {
446 let req = request.into_inner();
447
448 let app_command = steer_core::app::AppCommand::ProcessUserInput(req.message);
449
450 match self
451 .session_manager
452 .send_command(&req.session_id, app_command)
453 .await
454 {
455 Ok(()) => {
456 let operation_id = format!("op_{}", uuid::Uuid::new_v4());
458 Ok(Response::new(SendMessageResponse {
459 operation: Some(Operation {
460 id: operation_id,
461 session_id: req.session_id,
462 r#type: OperationType::SendMessage as i32,
463 status: OperationStatus::Running as i32,
464 created_at: Some(
465 prost_types::Timestamp::from(std::time::SystemTime::now()),
466 ),
467 completed_at: None,
468 metadata: std::collections::HashMap::new(),
469 }),
470 }))
471 }
472 Err(e) => {
473 error!("Failed to send message: {}", e);
474 Err(Status::internal(format!("Failed to send message: {e}")))
475 }
476 }
477 }
478
479 async fn approve_tool(
480 &self,
481 request: Request<ApproveToolRequest>,
482 ) -> Result<Response<ApproveToolResponse>, Status> {
483 let req = request.into_inner();
484
485 let approval = match req.decision {
486 Some(decision) => match decision {
487 proto::ApprovalDecision {
488 decision_type: Some(proto::approval_decision::DecisionType::Deny(true)),
489 } => steer_core::app::command::ApprovalType::Denied,
490 proto::ApprovalDecision {
491 decision_type: Some(proto::approval_decision::DecisionType::Once(true)),
492 } => steer_core::app::command::ApprovalType::Once,
493 proto::ApprovalDecision {
494 decision_type: Some(proto::approval_decision::DecisionType::AlwaysTool(true)),
495 } => steer_core::app::command::ApprovalType::AlwaysTool,
496 proto::ApprovalDecision {
497 decision_type:
498 Some(proto::approval_decision::DecisionType::AlwaysBashPattern(pattern)),
499 } => steer_core::app::command::ApprovalType::AlwaysBashPattern(pattern),
500 _ => {
501 return Err(Status::invalid_argument(
502 "Invalid approval decision enum value",
503 ));
504 }
505 },
506 None => {
507 return Err(Status::invalid_argument("Missing approval decision"));
508 }
509 };
510
511 let app_command = steer_core::app::AppCommand::HandleToolResponse {
512 id: req.tool_call_id,
513 approval,
514 };
515
516 match self
517 .session_manager
518 .send_command(&req.session_id, app_command)
519 .await
520 {
521 Ok(()) => Ok(Response::new(ApproveToolResponse {})),
522 Err(e) => {
523 error!("Failed to approve tool: {}", e);
524 Err(Status::internal(format!("Failed to approve tool: {e}")))
525 }
526 }
527 }
528
529 async fn activate_session(
530 &self,
531 request: Request<ActivateSessionRequest>,
532 ) -> Result<Response<Self::ActivateSessionStream>, Status> {
533 let req = request.into_inner();
534 let session_manager = self.session_manager.clone();
535 let llm_config_provider = self.llm_config_provider.clone();
536
537 info!("ActivateSession called for {}", req.session_id);
538
539 let stream = async_stream::try_stream! {
540 let state = if let Ok(Some(state)) = session_manager
542 .get_session_state(&req.session_id)
543 .await
544 {
545 state
546 } else {
547 let app_config = steer_core::app::AppConfig {
549 llm_config_provider: llm_config_provider.clone(),
550 };
551
552 session_manager
553 .resume_session(&req.session_id, app_config)
554 .await
555 .map_err(|e| Status::internal(format!("Failed to resume session: {e}")))?;
556
557 session_manager
559 .get_session_state(&req.session_id)
560 .await
561 .map_err(|e| Status::internal(format!("Failed to get session state: {e}")))?
562 .ok_or_else(|| Status::not_found(format!("Session not found: {}", req.session_id)))?
563 };
564
565 for msg in state.messages {
567 let proto_msg = message_to_proto(msg)
568 .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
569 yield ActivateSessionResponse {
570 chunk: Some(activate_session_response::Chunk::Message(proto_msg)),
571 };
572 }
573
574 yield ActivateSessionResponse {
576 chunk: Some(activate_session_response::Chunk::Footer(ActivateSessionFooter {
577 approved_tools: state.approved_tools.into_iter().collect(),
578 })),
579 };
580 };
581
582 Ok(Response::new(Box::pin(stream)))
583 }
584
585 async fn cancel_operation(
586 &self,
587 request: Request<CancelOperationRequest>,
588 ) -> Result<Response<CancelOperationResponse>, Status> {
589 let req = request.into_inner();
590
591 let app_command = steer_core::app::AppCommand::CancelProcessing;
592
593 match self
594 .session_manager
595 .send_command(&req.session_id, app_command)
596 .await
597 {
598 Ok(()) => Ok(Response::new(CancelOperationResponse {})),
599 Err(e) => {
600 error!("Failed to cancel operation: {}", e);
601 Err(Status::internal(format!("Failed to cancel operation: {e}")))
602 }
603 }
604 }
605
606 async fn list_files(
607 &self,
608 request: Request<ListFilesRequest>,
609 ) -> Result<Response<Self::ListFilesStream>, Status> {
610 let req = request.into_inner();
611
612 debug!("ListFiles called for session: {}", req.session_id);
613
614 let workspace = match self
616 .session_manager
617 .get_session_workspace(&req.session_id)
618 .await
619 {
620 Ok(Some(workspace)) => workspace,
621 Ok(None) => {
622 return Err(Status::not_found(format!(
623 "Session not found: {}",
624 req.session_id
625 )));
626 }
627 Err(e) => {
628 error!("Failed to get session workspace: {}", e);
629 return Err(Status::internal(format!(
630 "Failed to get session workspace: {e}"
631 )));
632 }
633 };
634
635 let (tx, rx) = mpsc::channel(100);
637
638 let _list_task: tokio::task::JoinHandle<()> = tokio::spawn(async move {
640 let query = if req.query.is_empty() {
642 None
643 } else {
644 Some(req.query.as_str())
645 };
646
647 let max_results = if req.max_results == 0 {
648 None
649 } else {
650 Some(req.max_results as usize)
651 };
652
653 match workspace.list_files(query, max_results).await {
654 Ok(files) => {
655 for chunk in files.chunks(1000) {
657 let response = ListFilesResponse {
658 paths: chunk.to_vec(),
659 };
660
661 if let Err(e) = tx.send(Ok(response)).await {
662 warn!("Failed to send file list chunk: {}", e);
663 break;
664 }
665 }
666 }
667 Err(e) => {
668 error!("Failed to list files: {}", e);
669 let _ = tx
670 .send(Err(Status::internal(format!("Failed to list files: {e}"))))
671 .await;
672 }
673 }
674 });
675
676 Ok(Response::new(ReceiverStream::new(rx)))
677 }
678
679 async fn get_mcp_servers(
680 &self,
681 request: Request<GetMcpServersRequest>,
682 ) -> Result<Response<GetMcpServersResponse>, Status> {
683 let req = request.into_inner();
684
685 debug!("GetMcpServers called for session: {}", req.session_id);
686
687 match self.session_manager.get_mcp_statuses(&req.session_id).await {
689 Ok(servers) => {
690 use crate::grpc::conversions::mcp_server_info_to_proto;
691
692 let proto_servers = servers.into_iter().map(mcp_server_info_to_proto).collect();
693
694 Ok(Response::new(GetMcpServersResponse {
695 servers: proto_servers,
696 }))
697 }
698 Err(e) => {
699 error!("Failed to get MCP server statuses: {}", e);
700 Err(Status::internal(format!(
701 "Failed to get MCP server statuses: {e}"
702 )))
703 }
704 }
705 }
706}
707
708async fn try_resume_session(
709 session_manager: &SessionManager,
710 session_id: &str,
711 llm_config_provider: &steer_core::config::LlmConfigProvider,
712) -> Result<(), Status> {
713 let app_config = steer_core::app::AppConfig {
714 llm_config_provider: llm_config_provider.clone(),
715 };
716
717 match session_manager.resume_session(session_id, app_config).await {
719 Ok(_command_tx) => {
720 info!("Successfully resumed session: {}", session_id);
721 Ok(())
723 }
724 Err(steer_core::error::Error::SessionManager(
725 steer_core::session::manager::SessionManagerError::CapacityExceeded { current, max },
726 )) => {
727 warn!(
728 "Cannot resume session {}: server at capacity ({}/{})",
729 session_id, current, max
730 );
731 Err(Status::resource_exhausted(format!(
732 "Server at maximum capacity ({current}/{max}). Cannot resume session."
733 )))
734 }
735 Err(e) => {
736 error!("Failed to resume session {}: {}", session_id, e);
737 Err(Status::internal(format!("Failed to resume session: {e}")))
738 }
739 }
740}
741
742async fn handle_client_message(
743 session_manager: &SessionManager,
744 client_message: StreamSessionRequest,
745) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
746 debug!(
747 "Handling client message for session: {}",
748 client_message.session_id
749 );
750
751 if let Some(message) = client_message.message {
752 match message {
753 stream_session_request::Message::SendMessage(send_msg) => {
754 let app_command = steer_core::app::AppCommand::ProcessUserInput(send_msg.message);
756
757 session_manager
758 .send_command(&client_message.session_id, app_command)
759 .await
760 .map_err(|e| format!("Failed to send message: {e}"))?;
761 }
762
763 stream_session_request::Message::ToolApproval(approval) => {
764 let approval_type = match approval.decision {
766 Some(decision) => match decision.decision_type {
767 Some(proto::approval_decision::DecisionType::Deny(_)) => {
768 steer_core::app::command::ApprovalType::Denied
769 }
770 Some(proto::approval_decision::DecisionType::Once(_)) => {
771 steer_core::app::command::ApprovalType::Once
772 }
773 Some(proto::approval_decision::DecisionType::AlwaysTool(_)) => {
774 steer_core::app::command::ApprovalType::AlwaysTool
775 }
776 Some(proto::approval_decision::DecisionType::AlwaysBashPattern(
777 pattern,
778 )) => steer_core::app::command::ApprovalType::AlwaysBashPattern(pattern),
779 None => {
780 return Err(
781 "Invalid approval decision: no decision type specified".into()
782 );
783 }
784 },
785 None => {
786 return Err("Invalid approval decision: no decision provided".into());
787 }
788 };
789
790 let app_command = steer_core::app::AppCommand::HandleToolResponse {
791 id: approval.tool_call_id,
792 approval: approval_type,
793 };
794
795 session_manager
796 .send_command(&client_message.session_id, app_command)
797 .await
798 .map_err(|e| format!("Failed to approve tool: {e}"))?;
799 }
800
801 stream_session_request::Message::Cancel(_cancel) => {
802 let app_command = steer_core::app::AppCommand::CancelProcessing;
804
805 session_manager
806 .send_command(&client_message.session_id, app_command)
807 .await
808 .map_err(|e| format!("Failed to cancel operation: {e}"))?;
809 }
810
811 stream_session_request::Message::Subscribe(_subscribe_request) => {
812 debug!("Subscribe message received - stream already established");
813 }
815
816 stream_session_request::Message::UpdateConfig(_update_config) => {
817 debug!("UpdateConfig received but provider changes are no longer supported");
820 }
821
822 stream_session_request::Message::ExecuteCommand(execute_command) => {
823 use steer_core::app::conversation::AppCommandType;
824 let app_cmd_type = match AppCommandType::parse(&execute_command.command) {
825 Ok(cmd) => cmd,
826 Err(e) => {
827 return Err(format!("Failed to parse command: {e}").into());
828 }
829 };
830 let app_command = steer_core::app::AppCommand::ExecuteCommand(app_cmd_type);
831 session_manager
832 .send_command(&client_message.session_id, app_command)
833 .await
834 .map_err(|e| format!("Failed to execute command: {e}"))?;
835 }
836
837 stream_session_request::Message::ExecuteBashCommand(execute_bash_command) => {
838 let app_command = steer_core::app::AppCommand::ExecuteBashCommand {
839 command: execute_bash_command.command,
840 };
841 session_manager
842 .send_command(&client_message.session_id, app_command)
843 .await
844 .map_err(|e| format!("Failed to execute bash command: {e}"))?;
845 }
846
847 stream_session_request::Message::EditMessage(edit_message) => {
848 let app_command = steer_core::app::AppCommand::EditMessage {
849 message_id: edit_message.message_id,
850 new_content: edit_message.new_content,
851 };
852 session_manager
853 .send_command(&client_message.session_id, app_command)
854 .await
855 .map_err(|e| format!("Failed to edit message: {e}"))?;
856 }
857 }
858 }
859
860 Ok(())
861}
862
#[cfg(test)]
mod tests {
    use super::*;
    use steer_core::api::Model;

    use std::collections::HashMap;
    use steer_core::session::state::WorkspaceConfig;
    use steer_core::session::stores::sqlite::SqliteSessionStore;
    use steer_core::session::{SessionConfig, SessionManagerConfig, SessionToolConfig};
    use steer_proto::agent::v1::agent_service_client::AgentServiceClient;
    use steer_proto::agent::v1::{SendMessageRequest, SubscribeRequest};
    use tempfile::TempDir;
    use tokio::sync::mpsc;
    use tokio_stream::StreamExt;

    fn create_test_app_config() -> steer_core::app::AppConfig {
        steer_core::test_utils::test_app_config()
    }

    /// Session config with a local workspace rooted at the current directory
    /// (falling back to `.`), default tool config and no metadata.
    fn local_session_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        }
    }

    /// A `Subscribe` stream request for `session_id` with no event filter.
    fn subscribe_request(session_id: &str) -> StreamSessionRequest {
        StreamSessionRequest {
            session_id: session_id.to_string(),
            message: Some(stream_session_request::Message::Subscribe(
                SubscribeRequest {
                    event_types: vec![],
                    since_sequence: None,
                },
            )),
        }
    }

    /// Session manager backed by a SQLite store in a fresh temp directory.
    ///
    /// The returned `TempDir` must be kept alive for the store to remain valid.
    async fn create_test_session_manager() -> (Arc<SessionManager>, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let store = Arc::new(SqliteSessionStore::new(&db_path).await.unwrap());

        let config = SessionManagerConfig {
            max_concurrent_sessions: 100,
            default_model: Model::ClaudeSonnet4_20250514,
            auto_persist: true,
        };
        let session_manager = Arc::new(SessionManager::new(store, config));

        (session_manager, temp_dir)
    }

    /// Start the gRPC service on an ephemeral localhost port.
    ///
    /// Returns the server URL, the shared session manager and its temp dir.
    async fn create_test_server() -> (String, Arc<SessionManager>, TempDir) {
        let (session_manager, temp_dir) = create_test_session_manager().await;

        let auth_storage = Arc::new(steer_core::test_utils::InMemoryAuthStorage::new());
        let llm_config_provider = steer_core::config::LlmConfigProvider::new(auth_storage);
        let service = AgentServiceImpl::new(session_manager.clone(), llm_config_provider);

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let _server_task = tokio::spawn(async move {
            tonic::transport::Server::builder()
                .add_service(agent_service_server::AgentServiceServer::new(service))
                .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
                .await
                .unwrap();
        });

        // Give the server a moment to start accepting connections.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let url = format!("http://{addr}");
        (url, session_manager, temp_dir)
    }

    #[tokio::test]
    async fn test_session_cleanup_on_disconnect() {
        let (session_manager, _temp_dir) = create_test_session_manager().await;

        let (session_id, _command_tx) = session_manager
            .create_session(local_session_config(), create_test_app_config())
            .await
            .unwrap();

        assert!(session_manager.is_session_active(&session_id).await);

        session_manager
            .increment_subscriber_count(&session_id)
            .await
            .unwrap();

        assert!(session_manager.is_session_active(&session_id).await);

        session_manager
            .decrement_subscriber_count(&session_id)
            .await
            .unwrap();

        session_manager
            .maybe_suspend_idle_session(&session_id)
            .await
            .unwrap();

        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after last client disconnects"
        );

        let session_info = session_manager.get_session(&session_id).await.unwrap();
        assert!(
            session_info.is_some(),
            "Session should still exist in storage after suspension"
        );
    }

    #[tokio::test]
    async fn test_session_with_multiple_subscribers() {
        let (session_manager, _temp_dir) = create_test_session_manager().await;

        let (session_id, _command_tx) = session_manager
            .create_session(local_session_config(), create_test_app_config())
            .await
            .unwrap();

        session_manager
            .increment_subscriber_count(&session_id)
            .await
            .unwrap();
        session_manager
            .increment_subscriber_count(&session_id)
            .await
            .unwrap();

        session_manager
            .decrement_subscriber_count(&session_id)
            .await
            .unwrap();
        session_manager
            .maybe_suspend_idle_session(&session_id)
            .await
            .unwrap();

        assert!(
            session_manager.is_session_active(&session_id).await,
            "Session should remain active with one subscriber"
        );

        session_manager
            .decrement_subscriber_count(&session_id)
            .await
            .unwrap();
        session_manager
            .maybe_suspend_idle_session(&session_id)
            .await
            .unwrap();

        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after all clients disconnect"
        );
    }

    #[tokio::test]
    async fn test_grpc_client_connect_disconnect_cleanup() {
        let (server_url, session_manager, _temp_dir) = create_test_server().await;

        let (session_id, _command_tx) = session_manager
            .create_session(local_session_config(), create_test_app_config())
            .await
            .unwrap();

        assert!(session_manager.is_session_active(&session_id).await);

        let mut client = AgentServiceClient::connect(server_url.clone())
            .await
            .unwrap();

        let request_stream = tokio_stream::iter(vec![subscribe_request(&session_id)]);

        let response = client.stream_session(request_stream).await.unwrap();
        let _stream = response.into_inner();

        let (msg_tx, msg_rx) = mpsc::channel(10);
        msg_tx
            .send(StreamSessionRequest {
                session_id: session_id.clone(),
                message: Some(stream_session_request::Message::SendMessage(
                    SendMessageRequest {
                        session_id: session_id.clone(),
                        message: "Hello, test!".to_string(),
                        attachments: vec![],
                    },
                )),
            })
            .await
            .unwrap();

        let request_stream = tokio_stream::wrappers::ReceiverStream::new(msg_rx);
        let response = client.stream_session(request_stream).await.unwrap();
        let mut stream = response.into_inner();

        // NOTE(review): this only checks that `next()` resolved within the
        // timeout; an `Err` item or end-of-stream would also pass.
        let timeout =
            tokio::time::timeout(tokio::time::Duration::from_secs(5), stream.next()).await;

        assert!(timeout.is_ok(), "Should receive at least one event");

        drop(stream);
        drop(msg_tx);

        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after client disconnect"
        );

        let session_info = session_manager.get_session(&session_id).await.unwrap();
        assert!(
            session_info.is_some(),
            "Session should still exist in storage"
        );
    }

    #[tokio::test]
    async fn test_grpc_basic_session_resume() {
        let (server_url, session_manager, _temp_dir) = create_test_server().await;

        let (session_id, _command_tx) = session_manager
            .create_session(local_session_config(), create_test_app_config())
            .await
            .unwrap();

        session_manager.suspend_session(&session_id).await.unwrap();
        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended"
        );

        let mut client = AgentServiceClient::connect(server_url.clone())
            .await
            .unwrap();

        let (msg_tx, msg_rx) = mpsc::channel(10);

        msg_tx.send(subscribe_request(&session_id)).await.unwrap();

        let request_stream = tokio_stream::wrappers::ReceiverStream::new(msg_rx);
        let response = client.stream_session(request_stream).await;

        assert!(
            response.is_ok(),
            "Should be able to connect to suspended session (auto-resume)"
        );

        let stream = response.unwrap().into_inner();

        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        assert!(
            session_manager.is_session_active(&session_id).await,
            "Session should be active after auto-resume"
        );

        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        assert!(
            session_manager.is_session_active(&session_id).await,
            "Session should remain active while client is connected"
        );

        drop(stream);
        drop(msg_tx);

        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after client disconnects"
        );
    }
}