steer_grpc/grpc/server.rs

1use crate::grpc::conversions::message_to_proto;
2use crate::grpc::session_manager_ext::SessionManagerExt;
3use std::sync::Arc;
4use steer_core::session::manager::SessionManager;
5use steer_proto::agent::v1::{self as proto, *};
6use tokio::sync::mpsc;
7use tokio_stream::wrappers::ReceiverStream;
8use tonic::{Request, Response, Status, Streaming};
9use tracing::{debug, error, info, warn};
10
11pub struct AgentServiceImpl {
12    session_manager: Arc<SessionManager>,
13    llm_config_provider: steer_core::config::LlmConfigProvider,
14}
15
16impl AgentServiceImpl {
17    pub fn new(
18        session_manager: Arc<SessionManager>,
19        llm_config_provider: steer_core::config::LlmConfigProvider,
20    ) -> Self {
21        Self {
22            session_manager,
23            llm_config_provider,
24        }
25    }
26}
27
28#[tonic::async_trait]
29impl agent_service_server::AgentService for AgentServiceImpl {
30    type StreamSessionStream = ReceiverStream<Result<StreamSessionResponse, Status>>;
31    type ListFilesStream = ReceiverStream<Result<ListFilesResponse, Status>>;
32    type GetSessionStream =
33        std::pin::Pin<Box<dyn futures::Stream<Item = Result<GetSessionResponse, Status>> + Send>>;
34    type GetConversationStream = std::pin::Pin<
35        Box<dyn futures::Stream<Item = Result<GetConversationResponse, Status>> + Send>,
36    >;
37    type ActivateSessionStream = std::pin::Pin<
38        Box<dyn futures::Stream<Item = Result<ActivateSessionResponse, Status>> + Send>,
39    >;
40
41    async fn stream_session(
42        &self,
43        request: Request<Streaming<StreamSessionRequest>>,
44    ) -> Result<Response<Self::StreamSessionStream>, Status> {
45        let mut client_stream = request.into_inner();
46        let (tx, rx) = mpsc::channel(100);
47
48        // Clone session manager and llm_config_provider for the stream handler task
49        let session_manager = self.session_manager.clone();
50        let llm_config_provider = self.llm_config_provider.clone();
51
52        let _stream_task: tokio::task::JoinHandle<()> = tokio::spawn(async move {
53            // Handle the first message to establish the session connection
54            let (session_id, mut event_rx) = if let Some(client_message_result) =
55                client_stream.message().await.transpose()
56            {
57                match client_message_result {
58                    Ok(client_message) => {
59                        let session_id = client_message.session_id.clone();
60
61                        // Try to take the event receiver for this session
62                        let receiver = match session_manager
63                            .take_event_receiver(&client_message.session_id)
64                            .await
65                        {
66                            Ok(receiver) => {
67                                // Session is already active - TUI will call GetConversation RPC to get history
68                                debug!("Session {} is already active, TUI should call GetConversation to retrieve history", session_id);
69                                receiver
70                            },
71                            Err(steer_core::error::Error::SessionManager(steer_core::session::manager::SessionManagerError::SessionNotActive { session_id })) => {
72                                info!("Session {} not active, attempting to resume", session_id);
73
74                                // Try to resume the session
75                                match try_resume_session(&session_manager, &session_id, &llm_config_provider).await {
76                                    Ok(()) => {
77                                        // Session resumed, try to take receiver again
78                                        match session_manager.take_event_receiver(&session_id).await {
79                                            Ok(receiver) => receiver,
80                                            Err(e) => {
81                                                error!("Failed to get event receiver after resuming session {}: {}", session_id, e);
82                                                let _ = tx
83                                                    .send(Err(Status::internal(format!(
84                                                        "Failed to establish stream after resuming session: {e}"
85                                                    ))))
86                                                    .await;
87                                                return;
88                                            }
89                                        }
90                                    }
91                                    Err(e) => {
92                                        error!("Failed to resume session {}: {}", session_id, e);
93                                        let _ = tx
94                                            .send(Err(e))
95                                            .await;
96                                        return;
97                                    }
98                                }
99                            }
100                            Err(steer_core::error::Error::SessionManager(steer_core::session::manager::SessionManagerError::SessionAlreadyHasListener { session_id })) => {
101                                error!("Session already has an active stream: {}", session_id);
102                                let _ = tx
103                                    .send(Err(Status::already_exists(format!(
104                                        "Session {session_id} already has an active stream"
105                                    ))))
106                                    .await;
107                                return;
108                            }
109                            Err(e) => {
110                                error!("Error taking event receiver: {}", e);
111                                let _ = tx
112                                    .send(Err(Status::internal(format!(
113                                        "Error establishing stream: {e}"
114                                    ))))
115                                    .await;
116                                return;
117                            }
118                        };
119
120                        // Process the first message
121                        if let Err(e) =
122                            handle_client_message(&session_manager, client_message).await
123                        {
124                            error!("Error handling first client message: {}", e);
125                            let _ = tx
126                                .send(Err(Status::internal(format!(
127                                    "Error processing message: {e}"
128                                ))))
129                                .await;
130                            return;
131                        }
132
133                        (session_id, receiver)
134                    }
135                    Err(e) => {
136                        error!("Error receiving first client message: {}", e);
137                        let _ = tx.send(Err(Status::internal("Stream error"))).await;
138                        return;
139                    }
140                }
141            } else {
142                error!("No initial client message received");
143                let _ = tx.send(Err(Status::internal("No initial message"))).await;
144                return;
145            };
146
147            let mut event_sequence = 0u64;
148
149            // Mark session as having an active subscriber
150            if let Err(e) = session_manager
151                .increment_subscriber_count(&session_id)
152                .await
153            {
154                warn!(
155                    "Failed to increment subscriber count for session {}: {}",
156                    session_id, e
157                );
158            }
159
160            // Spawn task to handle outgoing events (App -> Client)
161            let tx_clone = tx.clone();
162            let session_id_clone = session_id.clone();
163            let event_task = tokio::spawn(async move {
164                while let Some(app_event) = event_rx.recv().await {
165                    event_sequence += 1;
166                    let server_event = match crate::grpc::conversions::app_event_to_server_event(
167                        app_event,
168                        event_sequence,
169                    ) {
170                        Ok(event) => event,
171                        Err(e) => {
172                            warn!("Failed to convert app event to server event: {}", e);
173                            continue;
174                        }
175                    };
176
177                    if let Err(e) = tx_clone.send(Ok(server_event)).await {
178                        warn!("Failed to send event to client: {}", e);
179                        break;
180                    }
181                }
182                debug!(
183                    "Event forwarding task ended for session: {}",
184                    session_id_clone
185                );
186            });
187
188            // Handle incoming messages (Client -> App)
189            while let Some(client_message_result) = client_stream.message().await.transpose() {
190                match client_message_result {
191                    Ok(client_message) => {
192                        // Touch the session to update last activity
193                        if let Err(e) = session_manager.touch_session(&session_id).await {
194                            warn!("Failed to touch session {}: {}", session_id, e);
195                        }
196
197                        if let Err(e) =
198                            handle_client_message(&session_manager, client_message).await
199                        {
200                            error!("Error handling client message: {}", e);
201                            let _ = tx
202                                .send(Err(Status::internal(format!(
203                                    "Error processing message: {e}"
204                                ))))
205                                .await;
206                            break;
207                        }
208                    }
209                    Err(e) => {
210                        error!("Error receiving client message: {}", e);
211                        let _ = tx.send(Err(Status::internal("Stream error"))).await;
212                        break;
213                    }
214                }
215            }
216
217            // Clean up
218            event_task.abort();
219
220            // Decrement subscriber count
221            if let Err(e) = session_manager
222                .decrement_subscriber_count(&session_id)
223                .await
224            {
225                warn!(
226                    "Failed to decrement subscriber count for session {}: {}",
227                    session_id, e
228                );
229            }
230
231            info!("Client stream ended for session: {}", session_id);
232
233            // Check if we should suspend the session (no more subscribers)
234            if let Err(e) = session_manager
235                .maybe_suspend_idle_session(&session_id)
236                .await
237            {
238                warn!("Failed to check/suspend idle session {}: {}", session_id, e);
239            }
240        });
241
242        Ok(Response::new(ReceiverStream::new(rx)))
243    }
244
245    async fn create_session(
246        &self,
247        request: Request<CreateSessionRequest>,
248    ) -> Result<Response<CreateSessionResponse>, Status> {
249        let req = request.into_inner();
250
251        let app_config = steer_core::app::AppConfig {
252            llm_config_provider: self.llm_config_provider.clone(),
253        };
254
255        match self
256            .session_manager
257            .create_session_grpc(req, app_config)
258            .await
259        {
260            Ok((_session_id, session_info)) => Ok(Response::new(CreateSessionResponse {
261                session: Some(session_info),
262            })),
263            Err(e) => {
264                error!("Failed to create session: {}", e);
265                Err(e.into())
266            }
267        }
268    }
269
270    async fn list_sessions(
271        &self,
272        request: Request<ListSessionsRequest>,
273    ) -> Result<Response<ListSessionsResponse>, Status> {
274        let _req = request.into_inner();
275
276        // Create filter - for now just list all sessions
277        let filter = steer_core::session::SessionFilter::default();
278
279        match self.session_manager.list_sessions(filter).await {
280            Ok(sessions) => {
281                let proto_sessions = sessions
282                    .into_iter()
283                    .map(|session| SessionInfo {
284                        id: session.id,
285                        created_at: Some(prost_types::Timestamp::from(
286                            std::time::SystemTime::from(session.created_at),
287                        )),
288                        updated_at: Some(prost_types::Timestamp::from(
289                            std::time::SystemTime::from(session.updated_at),
290                        )),
291                        status: proto::SessionStatus::Active as i32,
292                        metadata: Some(proto::SessionMetadata {
293                            labels: session.metadata,
294                            annotations: std::collections::HashMap::new(),
295                        }),
296                    })
297                    .collect();
298
299                Ok(Response::new(ListSessionsResponse {
300                    sessions: proto_sessions,
301                    next_page_token: None,
302                }))
303            }
304            Err(e) => {
305                error!("Failed to list sessions: {}", e);
306                Err(Status::internal(format!("Failed to list sessions: {e}")))
307            }
308        }
309    }
310
311    async fn get_session(
312        &self,
313        request: Request<GetSessionRequest>,
314    ) -> Result<Response<Self::GetSessionStream>, Status> {
315        let req = request.into_inner();
316        let session_manager = self.session_manager.clone();
317
318        let stream = async_stream::try_stream! {
319            match session_manager.get_session_proto(&req.session_id).await {
320                Ok(Some(session_state)) => {
321                    // Send header
322                    yield GetSessionResponse {
323                        chunk: Some(get_session_response::Chunk::Header(SessionStateHeader {
324                            id: session_state.id,
325                            created_at: session_state.created_at,
326                            updated_at: session_state.updated_at,
327                            config: session_state.config,
328                        })),
329                    };
330
331                    // Stream messages one by one
332                    for message in session_state.messages {
333                        yield GetSessionResponse {
334                            chunk: Some(get_session_response::Chunk::Message(message)),
335                        };
336                    }
337
338                    // Stream tool calls
339                    for (key, value) in session_state.tool_calls {
340                        yield GetSessionResponse {
341                            chunk: Some(get_session_response::Chunk::ToolCall(ToolCallStateEntry {
342                                key,
343                                value: Some(value),
344                            })),
345                        };
346                    }
347
348                    // Send footer
349                    yield GetSessionResponse {
350                        chunk: Some(get_session_response::Chunk::Footer(SessionStateFooter {
351                            approved_tools: session_state.approved_tools,
352                            last_event_sequence: session_state.last_event_sequence,
353                            metadata: session_state.metadata,
354                        })),
355                    };
356                }
357                Ok(None) => {
358                    Err(Status::not_found(format!(
359                        "Session not found: {}",
360                        req.session_id
361                    )))?;
362                }
363                Err(e) => {
364                    error!("Failed to get session: {}", e);
365                    Err(Status::internal(format!("Failed to get session: {e}")))?;
366                }
367            }
368        };
369
370        Ok(Response::new(Box::pin(stream)))
371    }
372
373    async fn delete_session(
374        &self,
375        request: Request<DeleteSessionRequest>,
376    ) -> Result<Response<DeleteSessionResponse>, Status> {
377        let req = request.into_inner();
378
379        match self.session_manager.delete_session(&req.session_id).await {
380            Ok(true) => Ok(Response::new(DeleteSessionResponse {})),
381            Ok(false) => Err(Status::not_found(format!(
382                "Session not found: {}",
383                req.session_id
384            ))),
385            Err(e) => {
386                error!("Failed to delete session: {}", e);
387                Err(Status::internal(format!("Failed to delete session: {e}")))
388            }
389        }
390    }
391
392    async fn get_conversation(
393        &self,
394        request: Request<GetConversationRequest>,
395    ) -> Result<Response<Self::GetConversationStream>, Status> {
396        let req = request.into_inner();
397        let session_manager = self.session_manager.clone();
398
399        info!("GetConversation called for session: {}", req.session_id);
400
401        let stream = async_stream::try_stream! {
402            match session_manager.get_session_state(&req.session_id).await {
403                Ok(Some(session_state)) => {
404                    info!(
405                        "Found session state with {} messages and {} approved tools",
406                        session_state.messages.len(),
407                        session_state.approved_tools.len()
408                    );
409
410                    // Stream messages one by one
411                    for msg in session_state.messages {
412                        let proto_msg = message_to_proto(msg.clone())
413                            .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
414                        yield GetConversationResponse {
415                            chunk: Some(get_conversation_response::Chunk::Message(proto_msg)),
416                        };
417                    }
418
419                    // Send footer with approved tools
420                    yield GetConversationResponse {
421                        chunk: Some(get_conversation_response::Chunk::Footer(GetConversationFooter {
422                            approved_tools: session_state.approved_tools.into_iter().collect(),
423                        })),
424                    };
425                }
426                Ok(None) => {
427                    Err(Status::not_found(format!(
428                        "Session not found: {}",
429                        req.session_id
430                    )))?;
431                }
432                Err(e) => {
433                    error!("Failed to get session state: {}", e);
434                    Err(Status::internal(format!("Failed to get session state: {e}")))?;
435                }
436            }
437        };
438
439        Ok(Response::new(Box::pin(stream)))
440    }
441
442    async fn send_message(
443        &self,
444        request: Request<SendMessageRequest>,
445    ) -> Result<Response<SendMessageResponse>, Status> {
446        let req = request.into_inner();
447
448        let app_command = steer_core::app::AppCommand::ProcessUserInput(req.message);
449
450        match self
451            .session_manager
452            .send_command(&req.session_id, app_command)
453            .await
454        {
455            Ok(()) => {
456                // Generate operation ID for tracking
457                let operation_id = format!("op_{}", uuid::Uuid::new_v4());
458                Ok(Response::new(SendMessageResponse {
459                    operation: Some(Operation {
460                        id: operation_id,
461                        session_id: req.session_id,
462                        r#type: OperationType::SendMessage as i32,
463                        status: OperationStatus::Running as i32,
464                        created_at: Some(
465                            prost_types::Timestamp::from(std::time::SystemTime::now()),
466                        ),
467                        completed_at: None,
468                        metadata: std::collections::HashMap::new(),
469                    }),
470                }))
471            }
472            Err(e) => {
473                error!("Failed to send message: {}", e);
474                Err(Status::internal(format!("Failed to send message: {e}")))
475            }
476        }
477    }
478
479    async fn approve_tool(
480        &self,
481        request: Request<ApproveToolRequest>,
482    ) -> Result<Response<ApproveToolResponse>, Status> {
483        let req = request.into_inner();
484
485        let approval = match req.decision {
486            Some(decision) => match decision {
487                proto::ApprovalDecision {
488                    decision_type: Some(proto::approval_decision::DecisionType::Deny(true)),
489                } => steer_core::app::command::ApprovalType::Denied,
490                proto::ApprovalDecision {
491                    decision_type: Some(proto::approval_decision::DecisionType::Once(true)),
492                } => steer_core::app::command::ApprovalType::Once,
493                proto::ApprovalDecision {
494                    decision_type: Some(proto::approval_decision::DecisionType::AlwaysTool(true)),
495                } => steer_core::app::command::ApprovalType::AlwaysTool,
496                proto::ApprovalDecision {
497                    decision_type:
498                        Some(proto::approval_decision::DecisionType::AlwaysBashPattern(pattern)),
499                } => steer_core::app::command::ApprovalType::AlwaysBashPattern(pattern),
500                _ => {
501                    return Err(Status::invalid_argument(
502                        "Invalid approval decision enum value",
503                    ));
504                }
505            },
506            None => {
507                return Err(Status::invalid_argument("Missing approval decision"));
508            }
509        };
510
511        let app_command = steer_core::app::AppCommand::HandleToolResponse {
512            id: req.tool_call_id,
513            approval,
514        };
515
516        match self
517            .session_manager
518            .send_command(&req.session_id, app_command)
519            .await
520        {
521            Ok(()) => Ok(Response::new(ApproveToolResponse {})),
522            Err(e) => {
523                error!("Failed to approve tool: {}", e);
524                Err(Status::internal(format!("Failed to approve tool: {e}")))
525            }
526        }
527    }
528
529    async fn activate_session(
530        &self,
531        request: Request<ActivateSessionRequest>,
532    ) -> Result<Response<Self::ActivateSessionStream>, Status> {
533        let req = request.into_inner();
534        let session_manager = self.session_manager.clone();
535        let llm_config_provider = self.llm_config_provider.clone();
536
537        info!("ActivateSession called for {}", req.session_id);
538
539        let stream = async_stream::try_stream! {
540            // Check if already active or activate it
541            let state = if let Ok(Some(state)) = session_manager
542                .get_session_state(&req.session_id)
543                .await
544            {
545                state
546            } else {
547                // Not active, so activate it
548                let app_config = steer_core::app::AppConfig {
549                    llm_config_provider: llm_config_provider.clone(),
550                };
551
552                session_manager
553                    .resume_session(&req.session_id, app_config)
554                    .await
555                    .map_err(|e| Status::internal(format!("Failed to resume session: {e}")))?;
556
557                // Fetch state now that it's active
558                session_manager
559                    .get_session_state(&req.session_id)
560                    .await
561                    .map_err(|e| Status::internal(format!("Failed to get session state: {e}")))?
562                    .ok_or_else(|| Status::not_found(format!("Session not found: {}", req.session_id)))?
563            };
564
565            // Stream messages one by one
566            for msg in state.messages {
567                let proto_msg = message_to_proto(msg)
568                    .map_err(|e| Status::internal(format!("Failed to convert message: {e}")))?;
569                yield ActivateSessionResponse {
570                    chunk: Some(activate_session_response::Chunk::Message(proto_msg)),
571                };
572            }
573
574            // Send footer with approved tools
575            yield ActivateSessionResponse {
576                chunk: Some(activate_session_response::Chunk::Footer(ActivateSessionFooter {
577                    approved_tools: state.approved_tools.into_iter().collect(),
578                })),
579            };
580        };
581
582        Ok(Response::new(Box::pin(stream)))
583    }
584
585    async fn cancel_operation(
586        &self,
587        request: Request<CancelOperationRequest>,
588    ) -> Result<Response<CancelOperationResponse>, Status> {
589        let req = request.into_inner();
590
591        let app_command = steer_core::app::AppCommand::CancelProcessing;
592
593        match self
594            .session_manager
595            .send_command(&req.session_id, app_command)
596            .await
597        {
598            Ok(()) => Ok(Response::new(CancelOperationResponse {})),
599            Err(e) => {
600                error!("Failed to cancel operation: {}", e);
601                Err(Status::internal(format!("Failed to cancel operation: {e}")))
602            }
603        }
604    }
605
606    async fn list_files(
607        &self,
608        request: Request<ListFilesRequest>,
609    ) -> Result<Response<Self::ListFilesStream>, Status> {
610        let req = request.into_inner();
611
612        debug!("ListFiles called for session: {}", req.session_id);
613
614        // Get the session's workspace
615        let workspace = match self
616            .session_manager
617            .get_session_workspace(&req.session_id)
618            .await
619        {
620            Ok(Some(workspace)) => workspace,
621            Ok(None) => {
622                return Err(Status::not_found(format!(
623                    "Session not found: {}",
624                    req.session_id
625                )));
626            }
627            Err(e) => {
628                error!("Failed to get session workspace: {}", e);
629                return Err(Status::internal(format!(
630                    "Failed to get session workspace: {e}"
631                )));
632            }
633        };
634
635        // Create the response stream
636        let (tx, rx) = mpsc::channel(100);
637
638        // Spawn task to stream the files
639        let _list_task: tokio::task::JoinHandle<()> = tokio::spawn(async move {
640            // Get the file list from the workspace
641            let query = if req.query.is_empty() {
642                None
643            } else {
644                Some(req.query.as_str())
645            };
646
647            let max_results = if req.max_results == 0 {
648                None
649            } else {
650                Some(req.max_results as usize)
651            };
652
653            match workspace.list_files(query, max_results).await {
654                Ok(files) => {
655                    // Stream files in chunks of 1000
656                    for chunk in files.chunks(1000) {
657                        let response = ListFilesResponse {
658                            paths: chunk.to_vec(),
659                        };
660
661                        if let Err(e) = tx.send(Ok(response)).await {
662                            warn!("Failed to send file list chunk: {}", e);
663                            break;
664                        }
665                    }
666                }
667                Err(e) => {
668                    error!("Failed to list files: {}", e);
669                    let _ = tx
670                        .send(Err(Status::internal(format!("Failed to list files: {e}"))))
671                        .await;
672                }
673            }
674        });
675
676        Ok(Response::new(ReceiverStream::new(rx)))
677    }
678
679    async fn get_mcp_servers(
680        &self,
681        request: Request<GetMcpServersRequest>,
682    ) -> Result<Response<GetMcpServersResponse>, Status> {
683        let req = request.into_inner();
684
685        debug!("GetMcpServers called for session: {}", req.session_id);
686
687        // Get MCP server statuses from session manager
688        match self.session_manager.get_mcp_statuses(&req.session_id).await {
689            Ok(servers) => {
690                use crate::grpc::conversions::mcp_server_info_to_proto;
691
692                let proto_servers = servers.into_iter().map(mcp_server_info_to_proto).collect();
693
694                Ok(Response::new(GetMcpServersResponse {
695                    servers: proto_servers,
696                }))
697            }
698            Err(e) => {
699                error!("Failed to get MCP server statuses: {}", e);
700                Err(Status::internal(format!(
701                    "Failed to get MCP server statuses: {e}"
702                )))
703            }
704        }
705    }
706}
707
708async fn try_resume_session(
709    session_manager: &SessionManager,
710    session_id: &str,
711    llm_config_provider: &steer_core::config::LlmConfigProvider,
712) -> Result<(), Status> {
713    let app_config = steer_core::app::AppConfig {
714        llm_config_provider: llm_config_provider.clone(),
715    };
716
717    // Attempt to resume the session
718    match session_manager.resume_session(session_id, app_config).await {
719        Ok(_command_tx) => {
720            info!("Successfully resumed session: {}", session_id);
721            // TUI will call GetCurrentConversation when it connects
722            Ok(())
723        }
724        Err(steer_core::error::Error::SessionManager(
725            steer_core::session::manager::SessionManagerError::CapacityExceeded { current, max },
726        )) => {
727            warn!(
728                "Cannot resume session {}: server at capacity ({}/{})",
729                session_id, current, max
730            );
731            Err(Status::resource_exhausted(format!(
732                "Server at maximum capacity ({current}/{max}). Cannot resume session."
733            )))
734        }
735        Err(e) => {
736            error!("Failed to resume session {}: {}", session_id, e);
737            Err(Status::internal(format!("Failed to resume session: {e}")))
738        }
739    }
740}
741
742async fn handle_client_message(
743    session_manager: &SessionManager,
744    client_message: StreamSessionRequest,
745) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
746    debug!(
747        "Handling client message for session: {}",
748        client_message.session_id
749    );
750
751    if let Some(message) = client_message.message {
752        match message {
753            stream_session_request::Message::SendMessage(send_msg) => {
754                // Convert to AppCommand - just process user input since that's what exists
755                let app_command = steer_core::app::AppCommand::ProcessUserInput(send_msg.message);
756
757                session_manager
758                    .send_command(&client_message.session_id, app_command)
759                    .await
760                    .map_err(|e| format!("Failed to send message: {e}"))?;
761            }
762
763            stream_session_request::Message::ToolApproval(approval) => {
764                // Convert approval decision using existing HandleToolResponse
765                let approval_type = match approval.decision {
766                    Some(decision) => match decision.decision_type {
767                        Some(proto::approval_decision::DecisionType::Deny(_)) => {
768                            steer_core::app::command::ApprovalType::Denied
769                        }
770                        Some(proto::approval_decision::DecisionType::Once(_)) => {
771                            steer_core::app::command::ApprovalType::Once
772                        }
773                        Some(proto::approval_decision::DecisionType::AlwaysTool(_)) => {
774                            steer_core::app::command::ApprovalType::AlwaysTool
775                        }
776                        Some(proto::approval_decision::DecisionType::AlwaysBashPattern(
777                            pattern,
778                        )) => steer_core::app::command::ApprovalType::AlwaysBashPattern(pattern),
779                        None => {
780                            return Err(
781                                "Invalid approval decision: no decision type specified".into()
782                            );
783                        }
784                    },
785                    None => {
786                        return Err("Invalid approval decision: no decision provided".into());
787                    }
788                };
789
790                let app_command = steer_core::app::AppCommand::HandleToolResponse {
791                    id: approval.tool_call_id,
792                    approval: approval_type,
793                };
794
795                session_manager
796                    .send_command(&client_message.session_id, app_command)
797                    .await
798                    .map_err(|e| format!("Failed to approve tool: {e}"))?;
799            }
800
801            stream_session_request::Message::Cancel(_cancel) => {
802                // Use existing CancelProcessing command
803                let app_command = steer_core::app::AppCommand::CancelProcessing;
804
805                session_manager
806                    .send_command(&client_message.session_id, app_command)
807                    .await
808                    .map_err(|e| format!("Failed to cancel operation: {e}"))?;
809            }
810
811            stream_session_request::Message::Subscribe(_subscribe_request) => {
812                debug!("Subscribe message received - stream already established");
813                // No action needed - stream is already active
814            }
815
816            stream_session_request::Message::UpdateConfig(_update_config) => {
817                // UpdateConfig no longer supports changing the LLM provider
818                // Tool config updates are handled separately
819                debug!("UpdateConfig received but provider changes are no longer supported");
820            }
821
822            stream_session_request::Message::ExecuteCommand(execute_command) => {
823                use steer_core::app::conversation::AppCommandType;
824                let app_cmd_type = match AppCommandType::parse(&execute_command.command) {
825                    Ok(cmd) => cmd,
826                    Err(e) => {
827                        return Err(format!("Failed to parse command: {e}").into());
828                    }
829                };
830                let app_command = steer_core::app::AppCommand::ExecuteCommand(app_cmd_type);
831                session_manager
832                    .send_command(&client_message.session_id, app_command)
833                    .await
834                    .map_err(|e| format!("Failed to execute command: {e}"))?;
835            }
836
837            stream_session_request::Message::ExecuteBashCommand(execute_bash_command) => {
838                let app_command = steer_core::app::AppCommand::ExecuteBashCommand {
839                    command: execute_bash_command.command,
840                };
841                session_manager
842                    .send_command(&client_message.session_id, app_command)
843                    .await
844                    .map_err(|e| format!("Failed to execute bash command: {e}"))?;
845            }
846
847            stream_session_request::Message::EditMessage(edit_message) => {
848                let app_command = steer_core::app::AppCommand::EditMessage {
849                    message_id: edit_message.message_id,
850                    new_content: edit_message.new_content,
851                };
852                session_manager
853                    .send_command(&client_message.session_id, app_command)
854                    .await
855                    .map_err(|e| format!("Failed to edit message: {e}"))?;
856            }
857        }
858    }
859
860    Ok(())
861}
862
// Tests for session lifecycle around client connections: subscriber counting,
// idle suspension, and auto-resume when a client streams to a suspended session.
// Two tests drive the SessionManager directly; two go through a real in-process
// tonic server and AgentServiceClient.
#[cfg(test)]
mod tests {
    use super::*;
    use steer_core::api::Model;

    use std::collections::HashMap;
    use steer_core::session::state::WorkspaceConfig;
    use steer_core::session::stores::sqlite::SqliteSessionStore;
    use steer_core::session::{SessionConfig, SessionManagerConfig, SessionToolConfig};
    use steer_proto::agent::v1::agent_service_client::AgentServiceClient;
    use steer_proto::agent::v1::{SendMessageRequest, SubscribeRequest};
    use tempfile::TempDir;
    use tokio::sync::mpsc;
    use tokio_stream::StreamExt;

    // Thin alias over steer_core's shared test AppConfig.
    fn create_test_app_config() -> steer_core::app::AppConfig {
        steer_core::test_utils::test_app_config()
    }

    // Builds a SessionManager backed by a SQLite store at `<tempdir>/test.db`,
    // allowing up to 100 concurrent sessions with auto-persist enabled.
    // The TempDir is returned so the caller keeps the database directory alive
    // for the test's duration (TempDir deletes its directory when dropped).
    async fn create_test_session_manager() -> (Arc<SessionManager>, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let store = Arc::new(SqliteSessionStore::new(&db_path).await.unwrap());

        let config = SessionManagerConfig {
            max_concurrent_sessions: 100,
            default_model: Model::ClaudeSonnet4_20250514,
            auto_persist: true,
        };
        let session_manager = Arc::new(SessionManager::new(store, config));

        (session_manager, temp_dir)
    }

    // Starts AgentServiceImpl on an OS-assigned 127.0.0.1 port, with an
    // LlmConfigProvider over in-memory auth storage. Returns:
    // - the server's http:// URL;
    // - the same SessionManager the server uses, so tests can inspect session
    //   state directly;
    // - the TempDir guard.
    // The server runs in a spawned task whose JoinHandle is discarded.
    async fn create_test_server() -> (String, Arc<SessionManager>, TempDir) {
        let (session_manager, temp_dir) = create_test_session_manager().await;

        let auth_storage = Arc::new(steer_core::test_utils::InMemoryAuthStorage::new());
        let llm_config_provider = steer_core::config::LlmConfigProvider::new(auth_storage);
        let service = AgentServiceImpl::new(session_manager.clone(), llm_config_provider);

        // Start server on random port
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let _server_task = tokio::spawn(async move {
            tonic::transport::Server::builder()
                .add_service(agent_service_server::AgentServiceServer::new(service))
                .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
                .await
                .unwrap();
        });

        // Give server time to start
        // NOTE(review): fixed-delay readiness wait; may be flaky on slow runners.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let url = format!("http://{addr}");
        (url, session_manager, temp_dir)
    }

    // Drives the SessionManager directly (no gRPC). After one subscriber attaches
    // and detaches, `maybe_suspend_idle_session` should drop the session from
    // memory while it remains retrievable from storage.
    #[tokio::test]
    async fn test_session_cleanup_on_disconnect() {
        let (session_manager, _temp_dir) = create_test_session_manager().await;

        // Create a session
        let session_config = SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        };

        let app_config = create_test_app_config();

        let (session_id, _command_tx) = session_manager
            .create_session(session_config, app_config)
            .await
            .unwrap();

        // Verify session is active
        assert!(session_manager.is_session_active(&session_id).await);

        // Simulate a client connection by incrementing subscriber count
        session_manager
            .increment_subscriber_count(&session_id)
            .await
            .unwrap();

        // Verify session is still active
        assert!(session_manager.is_session_active(&session_id).await);

        // Simulate client disconnect by decrementing subscriber count
        session_manager
            .decrement_subscriber_count(&session_id)
            .await
            .unwrap();

        // Check if session should be suspended
        session_manager
            .maybe_suspend_idle_session(&session_id)
            .await
            .unwrap();

        // Verify session was suspended (not active in memory)
        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after last client disconnects"
        );

        // Verify session still exists in storage
        let session_info = session_manager.get_session(&session_id).await.unwrap();
        assert!(
            session_info.is_some(),
            "Session should still exist in storage after suspension"
        );
    }

    // With two subscribers, the first detach plus `maybe_suspend_idle_session`
    // must leave the session active. Only the second detach should let it suspend.
    #[tokio::test]
    async fn test_session_with_multiple_subscribers() {
        let (session_manager, _temp_dir) = create_test_session_manager().await;

        // Create a session
        let session_config = SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        };

        let app_config = create_test_app_config();

        let (session_id, _command_tx) = session_manager
            .create_session(session_config, app_config)
            .await
            .unwrap();

        // Simulate two clients connecting
        session_manager
            .increment_subscriber_count(&session_id)
            .await
            .unwrap();
        session_manager
            .increment_subscriber_count(&session_id)
            .await
            .unwrap();

        // First client disconnects
        session_manager
            .decrement_subscriber_count(&session_id)
            .await
            .unwrap();
        session_manager
            .maybe_suspend_idle_session(&session_id)
            .await
            .unwrap();

        // Session should still be active (one subscriber remaining)
        assert!(
            session_manager.is_session_active(&session_id).await,
            "Session should remain active with one subscriber"
        );

        // Second client disconnects
        session_manager
            .decrement_subscriber_count(&session_id)
            .await
            .unwrap();
        session_manager
            .maybe_suspend_idle_session(&session_id)
            .await
            .unwrap();

        // Now session should be suspended
        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after all clients disconnect"
        );
    }

    // End-to-end over gRPC: open StreamSession calls against a live session,
    // then drop the client side. After a 200 ms grace period the server should
    // have suspended the session while keeping it in storage.
    #[tokio::test]
    async fn test_grpc_client_connect_disconnect_cleanup() {
        let (server_url, session_manager, _temp_dir) = create_test_server().await;

        // Create a session first
        let session_config = SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        };

        let app_config = create_test_app_config();

        let (session_id, _command_tx) = session_manager
            .create_session(session_config, app_config)
            .await
            .unwrap();

        // Verify session is active
        assert!(session_manager.is_session_active(&session_id).await);

        // Connect client
        let mut client = AgentServiceClient::connect(server_url.clone())
            .await
            .unwrap();

        // Start streaming with subscribe message
        // This request stream is a one-shot iterator: it ends right after the
        // Subscribe message.
        let request_stream = tokio_stream::iter(vec![StreamSessionRequest {
            session_id: session_id.clone(),
            message: Some(stream_session_request::Message::Subscribe(
                SubscribeRequest {
                    event_types: vec![],
                    since_sequence: None,
                },
            )),
        }]);

        let response = client.stream_session(request_stream).await.unwrap();
        // NOTE(review): this first response stream stays bound until the end of
        // the test and is never dropped before the suspension assertion below.
        // Confirm the server treats the finished request stream as a disconnect.
        let _stream = response.into_inner();

        // Send a test message to verify session is working
        // The message is queued in the channel (capacity 10) before the second
        // StreamSession call starts consuming it.
        let (msg_tx, msg_rx) = mpsc::channel(10);
        msg_tx
            .send(StreamSessionRequest {
                session_id: session_id.clone(),
                message: Some(stream_session_request::Message::SendMessage(
                    SendMessageRequest {
                        session_id: session_id.clone(),
                        message: "Hello, test!".to_string(),
                        attachments: vec![],
                    },
                )),
            })
            .await
            .unwrap();

        // Create new request stream with the message channel
        let request_stream = tokio_stream::wrappers::ReceiverStream::new(msg_rx);
        let response = client.stream_session(request_stream).await.unwrap();
        let mut stream = response.into_inner();

        // Wait for some response to verify session is working
        let timeout =
            tokio::time::timeout(tokio::time::Duration::from_secs(5), stream.next()).await;

        // NOTE(review): `is_ok()` only proves `stream.next()` resolved within 5 s.
        // It also passes if the stream yielded `None` (ended) or an `Err` status.
        assert!(timeout.is_ok(), "Should receive at least one event");

        // Drop the stream to simulate client disconnect
        drop(stream);
        drop(msg_tx);

        // Give the server time to process the disconnect
        // NOTE(review): fixed-delay wait; timing-sensitive on slow runners.
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // Verify session was suspended (not active in memory)
        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after client disconnect"
        );

        // Verify session still exists in storage
        let session_info = session_manager.get_session(&session_id).await.unwrap();
        assert!(
            session_info.is_some(),
            "Session should still exist in storage"
        );
    }

    // End-to-end over gRPC:
    // - a manually suspended session is auto-resumed when a client opens
    //   StreamSession with a Subscribe message;
    // - it stays active while the request channel is held open;
    // - it is suspended again once the client drops its side.
    #[tokio::test]
    async fn test_grpc_basic_session_resume() {
        let (server_url, session_manager, _temp_dir) = create_test_server().await;

        // Create a session
        let session_config = SessionConfig {
            workspace: WorkspaceConfig::Local {
                path: std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
            },
            tool_config: SessionToolConfig::default(),
            system_prompt: None,
            metadata: HashMap::new(),
        };

        let app_config = create_test_app_config();

        let (session_id, _command_tx) = session_manager
            .create_session(session_config, app_config)
            .await
            .unwrap();

        // Suspend the session manually to simulate a disconnected state
        session_manager.suspend_session(&session_id).await.unwrap();
        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended"
        );

        // Try to reconnect - this should auto-resume the session
        let mut client = AgentServiceClient::connect(server_url.clone())
            .await
            .unwrap();

        // Use a channel to keep the stream alive
        // While `msg_tx` is held, the request stream does not end.
        let (msg_tx, msg_rx) = mpsc::channel(10);

        // Send initial subscribe message
        msg_tx
            .send(StreamSessionRequest {
                session_id: session_id.clone(),
                message: Some(stream_session_request::Message::Subscribe(
                    SubscribeRequest {
                        event_types: vec![],
                        since_sequence: None,
                    },
                )),
            })
            .await
            .unwrap();

        let request_stream = tokio_stream::wrappers::ReceiverStream::new(msg_rx);
        let response = client.stream_session(request_stream).await;

        // The connection should succeed (auto-resume should work)
        assert!(
            response.is_ok(),
            "Should be able to connect to suspended session (auto-resume)"
        );

        let stream = response.unwrap().into_inner();

        // Give time for auto-resume to complete
        // NOTE(review): fixed-delay waits here and below make this test timing-sensitive.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // Session should be active again after auto-resume
        assert!(
            session_manager.is_session_active(&session_id).await,
            "Session should be active after auto-resume"
        );

        // Keep the stream alive a bit longer to ensure it stays active
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        assert!(
            session_manager.is_session_active(&session_id).await,
            "Session should remain active while client is connected"
        );

        // Clean up - drop the stream to disconnect
        drop(stream);
        drop(msg_tx);

        // Give time for cleanup to run
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // Now session should be suspended again after disconnect
        assert!(
            !session_manager.is_session_active(&session_id).await,
            "Session should be suspended after client disconnects"
        );
    }
}