Skip to main content

enact_grpc/
lib.rs

1pub mod proto {
2    pub mod runtime {
3        pub mod v1 {
4            tonic::include_proto!("enact.runtime.v1");
5        }
6    }
7}
8
9use enact_cluster::{
10    load_cluster_config, DispatchStrategy, DispatchTask, Dispatcher, RedisNodeRegistry,
11    RedisTaskQueue, TaskId,
12};
13use enact_core::callable::{CallableRegistry, LlmCallable};
14use enact_core::graph::{Checkpoint, CheckpointStore, InMemoryCheckpointStore};
15use enact_core::inbox::{
16    ControlAction, ControlMessage, GuidanceMessage, GuidancePriority, InMemoryInboxStore,
17    InboxMessage, InboxStore,
18};
19use enact_core::kernel::ExecutionId;
20use enact_core::kernel::SpawnMode;
21use enact_core::providers::ModelProvider;
22use enact_providers::factory::create_default_provider;
23use proto::runtime::v1::*;
24use std::pin::Pin;
25use std::sync::Arc;
26use tokio_stream::{wrappers::ReceiverStream, Stream};
27use tonic::{Request, Response, Status};
28use tracing::{error, info};
29
30pub use proto::runtime::v1::runtime_service_server::{RuntimeService, RuntimeServiceServer};
31
32#[derive(Clone)]
33pub struct GrpcState {
34    pub registry: Arc<CallableRegistry>,
35    pub provider: Arc<dyn ModelProvider>,
36    pub inbox_store: Arc<dyn InboxStore>,
37    pub checkpoint_store: Arc<dyn CheckpointStore>,
38}
39
40impl Default for GrpcState {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46impl GrpcState {
47    fn create_provider() -> Arc<dyn ModelProvider> {
48        create_default_provider().unwrap_or_else(|e| {
49            panic!(
50                "Failed to create provider from config: {}. Set default_model_id in ~/.enact/config.yaml, or add providers.yaml with models and required env vars.",
51                e
52            )
53        })
54    }
55
56    /// Create state with provider from config (default_model_id in config.yaml).
57    pub fn new() -> Self {
58        let provider = Self::create_provider();
59
60        let registry = Arc::new(CallableRegistry::new());
61        register_default_callables(&registry, provider.clone());
62
63        let inbox_store = InMemoryInboxStore::shared();
64        let checkpoint_store = Arc::new(InMemoryCheckpointStore::new()) as Arc<dyn CheckpointStore>;
65
66        Self {
67            registry,
68            provider,
69            inbox_store,
70            checkpoint_store,
71        }
72    }
73
74    /// Create GrpcState with a custom inbox store (provider from config).
75    pub fn with_inbox_store(inbox_store: Arc<dyn InboxStore>) -> Self {
76        let provider = Self::create_provider();
77
78        let registry = Arc::new(CallableRegistry::new());
79        register_default_callables(&registry, provider.clone());
80
81        let checkpoint_store = Arc::new(InMemoryCheckpointStore::new()) as Arc<dyn CheckpointStore>;
82
83        Self {
84            registry,
85            provider,
86            inbox_store,
87            checkpoint_store,
88        }
89    }
90
91    /// Create GrpcState with custom inbox and checkpoint stores (provider from config).
92    pub fn with_stores(
93        inbox_store: Arc<dyn InboxStore>,
94        checkpoint_store: Arc<dyn CheckpointStore>,
95    ) -> Self {
96        let provider = Self::create_provider();
97
98        let registry = Arc::new(CallableRegistry::new());
99        register_default_callables(&registry, provider.clone());
100
101        Self {
102            registry,
103            provider,
104            inbox_store,
105            checkpoint_store,
106        }
107    }
108
109    /// Create GrpcState with an explicit provider (e.g. for tests without config).
110    pub fn with_provider(
111        provider: Arc<dyn ModelProvider>,
112        inbox_store: Arc<dyn InboxStore>,
113    ) -> Self {
114        let registry = Arc::new(CallableRegistry::new());
115        register_default_callables(&registry, provider.clone());
116        let checkpoint_store = Arc::new(InMemoryCheckpointStore::new()) as Arc<dyn CheckpointStore>;
117        Self {
118            registry,
119            provider,
120            inbox_store,
121            checkpoint_store,
122        }
123    }
124}
125
126fn register_default_callables(registry: &CallableRegistry, provider: Arc<dyn ModelProvider>) {
127    let assistant = LlmCallable::with_provider(
128        "assistant".to_string(),
129        "You are a helpful assistant.".to_string(),
130        provider.clone(),
131    );
132    registry.register("assistant".to_string(), Arc::new(assistant));
133
134    let coder = LlmCallable::with_provider(
135        "coder".to_string(),
136        "You are an expert programmer. Help with coding tasks, debugging, and code review."
137            .to_string(),
138        provider.clone(),
139    );
140    registry.register("coder".to_string(), Arc::new(coder));
141
142    info!("Registered {} callables for gRPC", registry.len());
143}
144
145async fn try_dispatch_background_task(req: &RunAgentRequest) -> anyhow::Result<Option<TaskId>> {
146    if !req.background {
147        return Ok(None);
148    }
149
150    let cluster = load_cluster_config()?;
151    let node_registry = Arc::new(RedisNodeRegistry::new(
152        &cluster.redis_url,
153        (cluster.worker.heartbeat_interval_secs.max(1)) * 3,
154    )?);
155    let task_queue = Arc::new(RedisTaskQueue::new(
156        &cluster.redis_url,
157        cluster.dispatcher.task_stream.clone(),
158    )?);
159    let dispatcher = Dispatcher::new(
160        match cluster.dispatcher.strategy {
161            DispatchStrategy::LeastLoaded => DispatchStrategy::LeastLoaded,
162            DispatchStrategy::RoundRobin => DispatchStrategy::RoundRobin,
163            DispatchStrategy::CapabilityMatch => DispatchStrategy::CapabilityMatch,
164        },
165        node_registry,
166        task_queue,
167    );
168
169    let context_json = serde_json::to_value(&req.context).unwrap_or_else(|_| serde_json::json!({}));
170    let spawn_mode = SpawnMode::Child {
171        background: req.background,
172        inherit_inbox: req.inherit_inbox,
173        policies: None,
174    };
175
176    let task = DispatchTask {
177        task_id: TaskId::new(),
178        agent_name: req.agent_name.clone(),
179        input: serde_json::json!({
180            "input": req.input,
181            "context": context_json,
182            "parent_execution_id": req.parent_execution_id,
183            "checkpoint_id": req.checkpoint_id,
184        }),
185        spawn_mode,
186        priority: 5,
187    };
188
189    let task_id = dispatcher.dispatch(task, None).await?;
190    Ok(Some(task_id))
191}
192
193pub struct RuntimeServiceImpl {
194    state: Arc<GrpcState>,
195}
196
197impl RuntimeServiceImpl {
198    pub fn new(state: Arc<GrpcState>) -> Self {
199        Self { state }
200    }
201}
202
203#[tonic::async_trait]
204impl RuntimeService for RuntimeServiceImpl {
205    async fn run_agent(
206        &self,
207        request: Request<RunAgentRequest>,
208    ) -> Result<Response<RunAgentResponse>, Status> {
209        let req = request.into_inner();
210        let run_id = uuid::Uuid::new_v4().to_string();
211
212        info!("RunAgent: {} (run_id: {})", req.agent_name, run_id);
213
214        let callable =
215            self.state.registry.get(&req.agent_name).ok_or_else(|| {
216                Status::not_found(format!("Agent '{}' not found", req.agent_name))
217            })?;
218
219        if let Ok(Some(task_id)) = try_dispatch_background_task(&req).await {
220            return Ok(Response::new(RunAgentResponse {
221                run_id,
222                output: format!("Queued background task {}", task_id),
223                iterations: 0,
224                tool_calls: vec![],
225            }));
226        }
227
228        let input = if !req.context.is_empty() {
229            let context_str = req
230                .context
231                .iter()
232                .map(|(k, v)| format!("{}: {}", k, v))
233                .collect::<Vec<_>>()
234                .join("\n");
235            format!("Context:\n{}\n\nTask: {}", context_str, req.input)
236        } else {
237            req.input
238        };
239
240        match callable.run(&input).await {
241            Ok(output) => {
242                let response = RunAgentResponse {
243                    run_id,
244                    output,
245                    iterations: 1,
246                    tool_calls: vec![],
247                };
248                Ok(Response::new(response))
249            }
250            Err(e) => {
251                error!("Agent execution failed: {}", e);
252                Err(Status::internal(format!("Execution failed: {}", e)))
253            }
254        }
255    }
256
257    type RunAgentStreamStream = Pin<Box<dyn Stream<Item = Result<StreamEvent, Status>> + Send>>;
258
259    type ResumeFromCheckpointStream =
260        Pin<Box<dyn Stream<Item = Result<StreamEvent, Status>> + Send>>;
261
262    async fn run_agent_stream(
263        &self,
264        request: Request<RunAgentRequest>,
265    ) -> Result<Response<Self::RunAgentStreamStream>, Status> {
266        let req = request.into_inner();
267        let run_id = uuid::Uuid::new_v4().to_string();
268
269        info!("RunAgentStream: {} (run_id: {})", req.agent_name, run_id);
270
271        let callable =
272            self.state.registry.get(&req.agent_name).ok_or_else(|| {
273                Status::not_found(format!("Agent '{}' not found", req.agent_name))
274            })?;
275
276        if let Ok(Some(task_id)) = try_dispatch_background_task(&req).await {
277            let (tx, rx) = tokio::sync::mpsc::channel(8);
278            let run_id_clone = run_id.clone();
279            let _ = tx
280                .send(Ok(StreamEvent {
281                    run_id: run_id_clone.clone(),
282                    event: Some(stream_event::Event::RunStarted(RunStarted {})),
283                }))
284                .await;
285            let _ = tx
286                .send(Ok(StreamEvent {
287                    run_id: run_id_clone,
288                    event: Some(stream_event::Event::RunCompleted(RunCompleted {
289                        final_output: format!("Queued background task {}", task_id),
290                    })),
291                }))
292                .await;
293            let stream = ReceiverStream::new(rx);
294            return Ok(Response::new(Box::pin(stream)));
295        }
296
297        let input = if !req.context.is_empty() {
298            let context_str = req
299                .context
300                .iter()
301                .map(|(k, v)| format!("{}: {}", k, v))
302                .collect::<Vec<_>>()
303                .join("\n");
304            format!("Context:\n{}\n\nTask: {}", context_str, req.input)
305        } else {
306            req.input
307        };
308
309        let (tx, rx) = tokio::sync::mpsc::channel(100);
310        let run_id_clone = run_id.clone();
311        let agent_name = req.agent_name.clone();
312        let checkpoint_store = self.state.checkpoint_store.clone();
313        let execution_id = ExecutionId::from_string(&run_id);
314
315        tokio::spawn(async move {
316            let _ = tx
317                .send(Ok(StreamEvent {
318                    run_id: run_id_clone.clone(),
319                    event: Some(stream_event::Event::RunStarted(RunStarted {})),
320                }))
321                .await;
322
323            match callable.run(&input).await {
324                Ok(output) => {
325                    let checkpoint =
326                        Checkpoint::new(execution_id.clone()).with_agent_name(&agent_name);
327                    if let Err(e) = checkpoint_store.save(checkpoint).await {
328                        error!("Failed to save checkpoint: {}", e);
329                    }
330                    let _ = tx
331                        .send(Ok(StreamEvent {
332                            run_id: run_id_clone,
333                            event: Some(stream_event::Event::RunCompleted(RunCompleted {
334                                final_output: output,
335                            })),
336                        }))
337                        .await;
338                }
339                Err(e) => {
340                    let _ = tx
341                        .send(Ok(StreamEvent {
342                            run_id: run_id_clone,
343                            event: Some(stream_event::Event::Error(ErrorEvent {
344                                message: e.to_string(),
345                                code: "EXECUTION_ERROR".to_string(),
346                            })),
347                        }))
348                        .await;
349                }
350            }
351        });
352
353        let stream = ReceiverStream::new(rx);
354        Ok(Response::new(Box::pin(stream)))
355    }
356
357    async fn resume_from_checkpoint(
358        &self,
359        request: Request<ResumeRequest>,
360    ) -> Result<Response<Self::ResumeFromCheckpointStream>, Status> {
361        let req = request.into_inner();
362        let checkpoint_id = req.checkpoint_id;
363
364        info!("ResumeFromCheckpoint: {}", checkpoint_id);
365
366        // Load checkpoint from store
367        let checkpoint = self
368            .state
369            .checkpoint_store
370            .load(&checkpoint_id)
371            .await
372            .map_err(|e| Status::internal(format!("Failed to load checkpoint: {}", e)))?
373            .ok_or_else(|| {
374                Status::not_found(format!("Checkpoint '{}' not found", checkpoint_id))
375            })?;
376
377        // Get the run_id from checkpoint to find the original agent
378        let run_id = checkpoint.run_id.clone();
379
380        // Build input from checkpoint state + optional new input
381        let input = if let Some(new_input) = req.new_input {
382            // If new input provided, append to checkpoint messages
383            format!(
384                "Previous context from checkpoint {}:\n{}\n\nNew input: {}",
385                checkpoint_id,
386                checkpoint
387                    .messages
388                    .iter()
389                    .map(|m| format!("{}: {}", m.role, m.content))
390                    .collect::<Vec<_>>()
391                    .join("\n"),
392                new_input
393            )
394        } else {
395            // Resume with checkpoint state
396            format!(
397                "Resuming from checkpoint {}. Previous conversation:\n{}",
398                checkpoint_id,
399                checkpoint
400                    .messages
401                    .iter()
402                    .map(|m| format!("{}: {}", m.role, m.content))
403                    .collect::<Vec<_>>()
404                    .join("\n")
405            )
406        };
407
408        // Retrieve agent name from checkpoint metadata, defaulting to "assistant" for
409        // checkpoints created before agent_name was stored in metadata
410        let agent_name = checkpoint.agent_name().unwrap_or("assistant");
411
412        let callable = self
413            .state
414            .registry
415            .get(agent_name)
416            .ok_or_else(|| Status::not_found(format!("Agent '{}' not found", agent_name)))?;
417
418        let (tx, rx) = tokio::sync::mpsc::channel(32);
419        let callable_clone = callable.clone();
420        let run_id_clone = run_id.as_str().to_string();
421
422        tokio::spawn(async move {
423            // Send run started event
424            let _ = tx
425                .send(Ok(StreamEvent {
426                    run_id: run_id_clone.clone(),
427                    event: Some(stream_event::Event::RunStarted(RunStarted {})),
428                }))
429                .await;
430
431            // Execute the resumed callable
432            match callable_clone.run(&input).await {
433                Ok(output) => {
434                    // Stream the output as content delta
435                    let _ = tx
436                        .send(Ok(StreamEvent {
437                            run_id: run_id_clone.clone(),
438                            event: Some(stream_event::Event::ContentDelta(ContentDelta {
439                                content: output.clone(),
440                            })),
441                        }))
442                        .await;
443
444                    // Send run completed
445                    let _ = tx
446                        .send(Ok(StreamEvent {
447                            run_id: run_id_clone,
448                            event: Some(stream_event::Event::RunCompleted(RunCompleted {
449                                final_output: output,
450                            })),
451                        }))
452                        .await;
453                }
454                Err(e) => {
455                    let _ = tx
456                        .send(Ok(StreamEvent {
457                            run_id: run_id_clone,
458                            event: Some(stream_event::Event::Error(ErrorEvent {
459                                message: e.to_string(),
460                                code: "EXECUTION_ERROR".to_string(),
461                            })),
462                        }))
463                        .await;
464                }
465            }
466        });
467
468        let stream = ReceiverStream::new(rx);
469        Ok(Response::new(Box::pin(stream)))
470    }
471
472    async fn get_agent_info(
473        &self,
474        request: Request<GetAgentInfoRequest>,
475    ) -> Result<Response<AgentInfo>, Status> {
476        let req = request.into_inner();
477
478        let _callable =
479            self.state.registry.get(&req.agent_name).ok_or_else(|| {
480                Status::not_found(format!("Agent '{}' not found", req.agent_name))
481            })?;
482
483        let info = AgentInfo {
484            name: req.agent_name,
485            description: "Enact agent".to_string(),
486            tools: vec![],
487        };
488
489        Ok(Response::new(info))
490    }
491
492    async fn list_agents(
493        &self,
494        _request: Request<ListAgentsRequest>,
495    ) -> Result<Response<ListAgentsResponse>, Status> {
496        let names = self.state.registry.list();
497        let agents = names
498            .into_iter()
499            .map(|name| AgentInfo {
500                name,
501                description: "Enact agent".to_string(),
502                tools: vec![],
503            })
504            .collect();
505
506        Ok(Response::new(ListAgentsResponse { agents }))
507    }
508
509    async fn health_check(
510        &self,
511        _request: Request<HealthCheckRequest>,
512    ) -> Result<Response<HealthCheckResponse>, Status> {
513        Ok(Response::new(HealthCheckResponse {
514            healthy: true,
515            version: env!("CARGO_PKG_VERSION").to_string(),
516        }))
517    }
518
519    async fn cancel(
520        &self,
521        request: Request<CancelRequest>,
522    ) -> Result<Response<CancelResponse>, Status> {
523        let req = request.into_inner();
524        let execution_id = ExecutionId::from_string(&req.run_id);
525
526        info!("Cancel request for execution: {}", execution_id);
527
528        // Create control message with Cancel action
529        let control_msg =
530            ControlMessage::new(execution_id.clone(), ControlAction::Cancel, "grpc_api");
531
532        // Add reason if provided
533        let control_msg = if let Some(reason) = req.reason {
534            control_msg.with_reason(reason)
535        } else {
536            control_msg
537        };
538
539        // Push to inbox store
540        self.state
541            .inbox_store
542            .push(&execution_id, InboxMessage::Control(control_msg));
543
544        info!(
545            "Cancel message pushed to inbox for execution: {}",
546            execution_id
547        );
548
549        Ok(Response::new(CancelResponse {
550            success: true,
551            run_id: req.run_id,
552            message: Some("Cancel request sent".to_string()),
553        }))
554    }
555
556    async fn pause(
557        &self,
558        request: Request<PauseRequest>,
559    ) -> Result<Response<PauseResponse>, Status> {
560        let req = request.into_inner();
561        let execution_id = ExecutionId::from_string(&req.run_id);
562
563        info!("Pause request for execution: {}", execution_id);
564
565        // Create control message with Pause action
566        let control_msg =
567            ControlMessage::new(execution_id.clone(), ControlAction::Pause, "grpc_api");
568
569        // Push to inbox store
570        self.state
571            .inbox_store
572            .push(&execution_id, InboxMessage::Control(control_msg));
573
574        info!(
575            "Pause message pushed to inbox for execution: {}",
576            execution_id
577        );
578
579        Ok(Response::new(PauseResponse {
580            success: true,
581            run_id: req.run_id,
582            checkpoint_id: None, // Checkpoint will be created by the execution loop
583        }))
584    }
585
586    type ResumeStream = Pin<Box<dyn Stream<Item = Result<StreamEvent, Status>> + Send>>;
587
588    async fn resume(
589        &self,
590        request: Request<ResumeExecutionRequest>,
591    ) -> Result<Response<Self::ResumeStream>, Status> {
592        let req = request.into_inner();
593        let execution_id = ExecutionId::from_string(&req.run_id);
594
595        info!("Resume request for execution: {}", execution_id);
596
597        // Create control message with Resume action
598        let control_msg =
599            ControlMessage::new(execution_id.clone(), ControlAction::Resume, "grpc_api");
600
601        // Push to inbox store
602        self.state
603            .inbox_store
604            .push(&execution_id, InboxMessage::Control(control_msg));
605
606        info!(
607            "Resume message pushed to inbox for execution: {}",
608            execution_id
609        );
610
611        // Create a stream that will emit events as the execution resumes
612        let (tx, rx) = tokio::sync::mpsc::channel(100);
613        let run_id = req.run_id.clone();
614
615        // Send initial event indicating resume was requested
616        tokio::spawn(async move {
617            let _ = tx
618                .send(Ok(StreamEvent {
619                    run_id: run_id.clone(),
620                    event: Some(stream_event::Event::RunStarted(RunStarted {})),
621                }))
622                .await;
623
624            // Note: The actual execution resumption will be handled by the execution loop
625            // when it processes the Resume control message from the inbox.
626            // This stream currently just acknowledges the resume request.
627            // A full implementation would connect to the actual execution stream.
628        });
629
630        let stream = ReceiverStream::new(rx);
631        Ok(Response::new(Box::pin(stream)))
632    }
633
634    async fn approve_plan(
635        &self,
636        request: Request<ApprovePlanRequest>,
637    ) -> Result<Response<ApprovePlanResponse>, Status> {
638        let req = request.into_inner();
639        let execution_id = ExecutionId::from_string(&req.run_id);
640
641        info!("Approve plan request for execution: {}", execution_id);
642
643        // Create guidance message approving the plan
644        let guidance_msg = GuidanceMessage::from_user(
645            execution_id.clone(),
646            "PLAN_APPROVED: User approved the proposed plan. Proceed with execution.",
647        )
648        .with_priority(GuidancePriority::High);
649
650        // Push to inbox store
651        self.state
652            .inbox_store
653            .push(&execution_id, InboxMessage::Guidance(guidance_msg));
654
655        info!(
656            "Plan approval message pushed to inbox for execution: {}",
657            execution_id
658        );
659
660        Ok(Response::new(ApprovePlanResponse {
661            success: true,
662            run_id: req.run_id,
663        }))
664    }
665
666    async fn reject_plan(
667        &self,
668        request: Request<RejectPlanRequest>,
669    ) -> Result<Response<RejectPlanResponse>, Status> {
670        let req = request.into_inner();
671        let execution_id = ExecutionId::from_string(&req.run_id);
672
673        info!("Reject plan request for execution: {}", execution_id);
674
675        // Create guidance message rejecting the plan
676        let reason = req
677            .reason
678            .unwrap_or_else(|| "No reason provided".to_string());
679        let guidance_msg = GuidanceMessage::from_user(
680            execution_id.clone(),
681            format!(
682                "PLAN_REJECTED: User rejected the proposed plan. Reason: {}",
683                reason
684            ),
685        )
686        .with_priority(GuidancePriority::High);
687
688        // Push to inbox store
689        self.state
690            .inbox_store
691            .push(&execution_id, InboxMessage::Guidance(guidance_msg));
692
693        info!(
694            "Plan rejection message pushed to inbox for execution: {}",
695            execution_id
696        );
697
698        Ok(Response::new(RejectPlanResponse {
699            success: true,
700            run_id: req.run_id,
701        }))
702    }
703}
704
705pub async fn serve_grpc(state: Arc<GrpcState>, addr: &str) -> Result<(), anyhow::Error> {
706    let service = RuntimeServiceImpl::new(state);
707    let addr = addr.parse()?;
708
709    info!("gRPC server listening on {}", addr);
710
711    tonic::transport::Server::builder()
712        .add_service(RuntimeServiceServer::new(service))
713        .serve(addr)
714        .await?;
715
716    Ok(())
717}
718
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use enact_core::inbox::InMemoryInboxStore;
    use enact_core::providers::{ChatRequest, ChatResponse};
    use tokio_stream::StreamExt;

    /// Provider stub so the state can be built without any configuration;
    /// its `chat` always fails.
    struct TestProvider;

    #[async_trait]
    impl ModelProvider for TestProvider {
        fn name(&self) -> &str {
            "test"
        }
        fn model(&self) -> &str {
            "test-model"
        }
        async fn chat(&self, _: ChatRequest) -> anyhow::Result<ChatResponse> {
            anyhow::bail!("test only")
        }
    }

    /// State wired to the stub provider and the shared in-memory inbox store.
    fn create_test_state() -> Arc<GrpcState> {
        Arc::new(GrpcState::with_provider(
            Arc::new(TestProvider),
            InMemoryInboxStore::shared(),
        ))
    }

    /// Service under test, backed by a fresh test state.
    fn create_test_service() -> RuntimeServiceImpl {
        RuntimeServiceImpl::new(create_test_state())
    }

    /// Random run id in the same format the service generates.
    fn fresh_run_id() -> String {
        uuid::Uuid::new_v4().to_string()
    }

    #[tokio::test]
    async fn test_health_check() {
        let service = create_test_service();

        let health = service
            .health_check(Request::new(HealthCheckRequest {}))
            .await
            .unwrap()
            .into_inner();

        assert!(health.healthy);
        assert!(!health.version.is_empty());
    }

    #[tokio::test]
    async fn test_list_agents() {
        let service = create_test_service();

        let listing = service
            .list_agents(Request::new(ListAgentsRequest {}))
            .await
            .unwrap()
            .into_inner();

        assert!(!listing.agents.is_empty());
        // Should have at least assistant and coder
        for expected in ["assistant", "coder"].iter() {
            assert!(listing.agents.iter().any(|a| a.name == *expected));
        }
    }

    #[tokio::test]
    async fn test_get_agent_info() {
        let service = create_test_service();
        let request = Request::new(GetAgentInfoRequest {
            agent_name: "assistant".to_string(),
        });

        let info = service.get_agent_info(request).await.unwrap().into_inner();

        assert_eq!(info.name, "assistant");
        assert!(!info.description.is_empty());
    }

    #[tokio::test]
    async fn test_get_agent_info_not_found() {
        let service = create_test_service();
        let request = Request::new(GetAgentInfoRequest {
            agent_name: "nonexistent".to_string(),
        });

        let outcome = service.get_agent_info(request).await;
        assert!(outcome.is_err());

        let status = outcome.unwrap_err();
        assert_eq!(status.code(), tonic::Code::NotFound);
    }

    #[tokio::test]
    async fn test_cancel_execution() {
        let service = create_test_service();
        let run_id = fresh_run_id();
        let request = Request::new(CancelRequest {
            run_id: run_id.clone(),
            reason: Some("Test cancellation".to_string()),
        });

        let reply = service.cancel(request).await.unwrap().into_inner();

        assert!(reply.success);
        assert_eq!(reply.run_id, run_id);
        assert!(reply.message.is_some());
    }

    #[tokio::test]
    async fn test_pause_execution() {
        let service = create_test_service();
        let run_id = fresh_run_id();
        let request = Request::new(PauseRequest {
            run_id: run_id.clone(),
        });

        let reply = service.pause(request).await.unwrap().into_inner();

        assert!(reply.success);
        assert_eq!(reply.run_id, run_id);
    }

    #[tokio::test]
    async fn test_resume_execution() {
        let service = create_test_service();
        let run_id = fresh_run_id();
        let request = Request::new(ResumeExecutionRequest {
            run_id: run_id.clone(),
            checkpoint_id: None,
        });

        let mut events = service.resume(request).await.unwrap().into_inner();

        // Should receive at least RunStarted event
        let first = events.next().await;
        assert!(first.is_some());

        let event = first.unwrap().unwrap();
        assert_eq!(event.run_id, run_id);
    }

    #[tokio::test]
    async fn test_approve_plan() {
        let service = create_test_service();
        let run_id = fresh_run_id();
        let request = Request::new(ApprovePlanRequest {
            run_id: run_id.clone(),
        });

        let reply = service.approve_plan(request).await.unwrap().into_inner();

        assert!(reply.success);
        assert_eq!(reply.run_id, run_id);
    }

    #[tokio::test]
    async fn test_reject_plan() {
        let service = create_test_service();
        let run_id = fresh_run_id();
        let request = Request::new(RejectPlanRequest {
            run_id: run_id.clone(),
            reason: Some("Not aligned with goals".to_string()),
        });

        let reply = service.reject_plan(request).await.unwrap().into_inner();

        assert!(reply.success);
        assert_eq!(reply.run_id, run_id);
    }

    #[tokio::test]
    async fn test_run_agent_agent_not_found() {
        let service = create_test_service();
        let request = Request::new(RunAgentRequest {
            agent_name: "nonexistent".to_string(),
            input: "Hello".to_string(),
            context: std::collections::HashMap::new(),
            checkpoint_id: None,
            background: false,
            inherit_inbox: false,
            parent_execution_id: None,
        });

        let outcome = service.run_agent(request).await;
        assert!(outcome.is_err());

        let status = outcome.unwrap_err();
        assert_eq!(status.code(), tonic::Code::NotFound);
    }
}
922}