1use crate::capabilities::CapabilityRegistry;
15use crate::command::CommandResult;
16use crate::error::{AgentLoopError, Result};
17use crate::llm_driver_registry::{
18 BoxedLlmDriver, DriverRegistry, LlmCallConfig, LlmCallConfigBuilder, LlmMessage,
19 LlmMessageRole, LlmResponseStream, ProviderConfig, ToolSearchConfig,
20};
21use crate::message::{Controls, Message, MessageRole, patch_dangling_tool_calls};
22use crate::message_retriever::MessageRetriever;
23use crate::runtime_context::{AssembledTurnContext, inspect_turn_context};
24use crate::session::Session;
25use crate::traits::{
26 AgentStore, HarnessStore, ImageResolver, LlmProviderStore, ModelWithProvider, ResolvedImage,
27 SessionFileSystem, SessionStore,
28};
29use crate::typed_id::SessionId;
30use crate::user_facing_error::{UserFacingErrorContext, classify_runtime_error_message};
31use async_trait::async_trait;
32use std::collections::{HashMap, HashSet};
33use std::sync::Arc;
34use uuid::Uuid;
35
36#[derive(Debug, Clone)]
45pub struct CommandTurnContext {
46 pub session: Session,
48 pub messages: Vec<Message>,
50 pub system_prompt: String,
52 pub model: String,
54 pub provider_type: String,
56 pub resolved_locale: Option<String>,
58}
59
60#[derive(Debug, Clone, Default)]
65pub struct SessionCompletionRequest {
66 pub system_prompts: Vec<String>,
69 pub messages: Vec<Message>,
71 pub controls: Option<Controls>,
74 pub metadata: HashMap<String, String>,
77}
78
/// Result of a non-streaming session completion.
#[derive(Debug, Clone)]
pub struct SessionCompletion {
    /// Response text. `StoreCommandHost` trims it and treats an empty result
    /// as an error, so this is non-empty when produced by that host.
    pub text: String,
}
85
86pub struct SessionCompletionStream {
90 pub events: LlmResponseStream,
92 pub context: UserFacingErrorContext,
96}
97
98impl std::fmt::Debug for SessionCompletionStream {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("SessionCompletionStream")
101 .field("context", &self.context)
102 .finish()
103 }
104}
105
106#[derive(Debug)]
108pub enum SessionCompletionError {
109 InvalidRequest(AgentLoopError),
112 StreamingUnsupported,
116 Completion {
119 error: String,
121 context: UserFacingErrorContext,
123 },
124}
125
126impl SessionCompletionError {
127 pub fn into_command_result(self) -> Result<CommandResult> {
131 match self {
132 Self::InvalidRequest(error) => Err(error),
133 Self::StreamingUnsupported => Err(AgentLoopError::config(
134 "command host does not support streaming completions",
135 )),
136 Self::Completion { error, context } => {
137 let classified = classify_runtime_error_message(&error, &context);
138 Ok(CommandResult {
139 success: false,
140 message: classified.fallback_message(),
141 error_code: Some(classified.code.clone()),
142 error_fields: classified.error_fields(),
143 })
144 }
145 }
146 }
147}
148
149#[async_trait]
154pub trait CommandHost: Send + Sync {
155 async fn turn_context(&self) -> Result<CommandTurnContext>;
157
158 async fn completion(
161 &self,
162 request: SessionCompletionRequest,
163 ) -> std::result::Result<SessionCompletion, SessionCompletionError>;
164
165 async fn completion_stream(
173 &self,
174 _request: SessionCompletionRequest,
175 ) -> std::result::Result<SessionCompletionStream, SessionCompletionError> {
176 Err(SessionCompletionError::StreamingUnsupported)
177 }
178}
179
/// A [`CommandHost`] that provides nothing: every operation returns an error.
///
/// Stateless, so it is freely `Copy`/`Default`-constructible and printable.
#[derive(Debug, Clone, Copy, Default)]
pub struct DisabledCommandHost;
183
184#[async_trait]
185impl CommandHost for DisabledCommandHost {
186 async fn turn_context(&self) -> Result<CommandTurnContext> {
187 Err(AgentLoopError::config(
188 "command host does not provide turn-context access",
189 ))
190 }
191
192 async fn completion(
193 &self,
194 _request: SessionCompletionRequest,
195 ) -> std::result::Result<SessionCompletion, SessionCompletionError> {
196 Err(SessionCompletionError::InvalidRequest(
197 AgentLoopError::config("command host does not provide session completions"),
198 ))
199 }
200}
201
202pub struct StoreCommandHost {
210 session_id: SessionId,
211 harness_store: Arc<dyn HarnessStore>,
212 agent_store: Arc<dyn AgentStore>,
213 session_store: Arc<dyn SessionStore>,
214 message_retriever: Arc<dyn MessageRetriever>,
215 provider_store: Arc<dyn LlmProviderStore>,
216 capability_registry: CapabilityRegistry,
217 driver_registry: DriverRegistry,
218 image_resolver: Option<Arc<dyn ImageResolver>>,
219 file_store: Option<Arc<dyn SessionFileSystem>>,
220 assembled: tokio::sync::OnceCell<AssembledTurnContext>,
221}
222
223impl StoreCommandHost {
224 #[allow(clippy::too_many_arguments)]
225 pub fn new(
226 session_id: SessionId,
227 harness_store: Arc<dyn HarnessStore>,
228 agent_store: Arc<dyn AgentStore>,
229 session_store: Arc<dyn SessionStore>,
230 message_retriever: Arc<dyn MessageRetriever>,
231 provider_store: Arc<dyn LlmProviderStore>,
232 capability_registry: CapabilityRegistry,
233 driver_registry: DriverRegistry,
234 ) -> Self {
235 Self {
236 session_id,
237 harness_store,
238 agent_store,
239 session_store,
240 message_retriever,
241 provider_store,
242 capability_registry,
243 driver_registry,
244 image_resolver: None,
245 file_store: None,
246 assembled: tokio::sync::OnceCell::new(),
247 }
248 }
249
250 pub fn with_image_resolver(mut self, image_resolver: Arc<dyn ImageResolver>) -> Self {
253 self.image_resolver = Some(image_resolver);
254 self
255 }
256
257 pub fn with_file_store(mut self, file_store: Arc<dyn SessionFileSystem>) -> Self {
260 self.file_store = Some(file_store);
261 self
262 }
263
264 pub fn with_assembled_context(mut self, assembled: AssembledTurnContext) -> Self {
267 self.assembled = tokio::sync::OnceCell::new_with(Some(assembled));
268 self
269 }
270
271 async fn assembled(&self) -> Result<&AssembledTurnContext> {
272 self.assembled
273 .get_or_try_init(|| async {
274 let session = self
275 .session_store
276 .get_session(self.session_id)
277 .await?
278 .ok_or_else(|| AgentLoopError::session_not_found(self.session_id))?;
279 inspect_turn_context(
280 self.harness_store.as_ref(),
281 self.agent_store.as_ref(),
282 self.session_store.as_ref(),
283 self.message_retriever.as_ref(),
284 self.provider_store.as_ref(),
285 &self.capability_registry,
286 self.session_id,
287 session.harness_id,
288 session.agent_id,
289 &[],
290 self.file_store.clone(),
291 )
292 .await
293 })
294 .await
295 }
296
297 async fn resolve_images(&self, messages: &[Message]) -> HashMap<Uuid, ResolvedImage> {
298 let Some(resolver) = &self.image_resolver else {
299 return HashMap::new();
300 };
301 let image_ids: HashSet<Uuid> = messages
302 .iter()
303 .flat_map(LlmMessage::extract_image_file_ids)
304 .collect();
305 let mut resolved = HashMap::new();
306 for image_id in image_ids {
307 if let Ok(Some(image)) = resolver.resolve_image(image_id).await {
308 resolved.insert(image_id, image);
309 }
310 }
311 resolved
312 }
313
314 async fn resolve_completion_model(
318 &self,
319 controls: Option<&Controls>,
320 assembled: &AssembledTurnContext,
321 ) -> std::result::Result<ModelWithProvider, SessionCompletionError> {
322 let requested = controls.and_then(|controls| controls.model_id);
323 match requested {
324 Some(model_id) if Some(model_id) != assembled.resolved_model_id => self
325 .provider_store
326 .get_model_with_provider(model_id)
327 .await
328 .map_err(SessionCompletionError::InvalidRequest)?
329 .ok_or_else(|| {
330 SessionCompletionError::InvalidRequest(AgentLoopError::config(format!(
331 "Model not found: {model_id}"
332 )))
333 }),
334 _ => Ok(assembled.model_with_provider.clone()),
335 }
336 }
337
338 async fn prepare_completion(
342 &self,
343 request: SessionCompletionRequest,
344 ) -> std::result::Result<PreparedCompletion, SessionCompletionError> {
345 let assembled = self
346 .assembled()
347 .await
348 .map_err(SessionCompletionError::InvalidRequest)?;
349 let model = self
350 .resolve_completion_model(request.controls.as_ref(), assembled)
351 .await?;
352
353 let context = UserFacingErrorContext::default()
354 .with_provider(model.provider_type.to_string())
355 .with_model_id(model.model.clone());
356
357 let messages = patch_dangling_tool_calls(&request.messages);
358 let resolved_images = self.resolve_images(&messages).await;
359
360 let mut llm_messages: Vec<LlmMessage> = request
361 .system_prompts
362 .iter()
363 .filter(|prompt| !prompt.is_empty())
364 .map(|prompt| LlmMessage::text(LlmMessageRole::System, prompt.clone()))
365 .collect();
366 for msg in &messages {
367 let mut llm_msg = LlmMessage::from_message_with_images(msg, &resolved_images);
368 if msg.role == MessageRole::User
369 && let Some(actor) = &msg.external_actor
370 {
371 llm_msg.prepend_text_prefix(&format!("[{}] ", actor.display_label()));
372 }
373 llm_messages.push(llm_msg);
374 }
375
376 let mut llm_config_builder = LlmCallConfigBuilder::from(&assembled.runtime_agent)
377 .model(&model.model)
378 .tools(vec![])
379 .tool_search(ToolSearchConfig {
380 enabled: false,
381 threshold: usize::MAX,
382 })
383 .previous_response_id(None)
384 .with_metadata("session_id", self.session_id.to_string());
385 if let Some(effort) = request
386 .controls
387 .as_ref()
388 .and_then(|controls| controls.reasoning.as_ref())
389 .and_then(|reasoning| reasoning.effort.clone())
390 .filter(|value| !value.is_empty())
391 {
392 llm_config_builder = llm_config_builder.reasoning_effort(effort);
393 }
394 for (key, value) in &request.metadata {
395 llm_config_builder = llm_config_builder.with_metadata(key, value);
396 }
397 let llm_config = llm_config_builder.build();
398
399 let driver = self
400 .driver_registry
401 .create_driver(&ProviderConfig::from(&model))
402 .map_err(|error| SessionCompletionError::Completion {
403 error: error.to_string(),
404 context: context.clone(),
405 })?;
406
407 Ok(PreparedCompletion {
408 llm_messages,
409 llm_config,
410 driver,
411 context,
412 })
413 }
414}
415
416struct PreparedCompletion {
419 llm_messages: Vec<LlmMessage>,
420 llm_config: LlmCallConfig,
421 driver: BoxedLlmDriver,
422 context: UserFacingErrorContext,
423}
424
425#[async_trait]
426impl CommandHost for StoreCommandHost {
427 async fn turn_context(&self) -> Result<CommandTurnContext> {
428 let assembled = self.assembled().await?;
429 Ok(CommandTurnContext {
430 session: assembled.session.clone(),
431 messages: assembled.messages.clone(),
432 system_prompt: assembled.runtime_agent.system_prompt.clone(),
433 model: assembled.model_with_provider.model.clone(),
434 provider_type: assembled.model_with_provider.provider_type.to_string(),
435 resolved_locale: assembled.resolved_locale.clone(),
436 })
437 }
438
439 async fn completion(
440 &self,
441 request: SessionCompletionRequest,
442 ) -> std::result::Result<SessionCompletion, SessionCompletionError> {
443 let prepared = self.prepare_completion(request).await?;
444 let completion_error = |error: String| SessionCompletionError::Completion {
445 error,
446 context: prepared.context.clone(),
447 };
448
449 let response = prepared
450 .driver
451 .chat_completion(prepared.llm_messages, &prepared.llm_config)
452 .await
453 .map_err(|error| completion_error(error.to_string()))?;
454
455 let text = response.text.trim().to_string();
456 if text.is_empty() {
457 return Err(completion_error(
458 "session completion returned an empty response".to_string(),
459 ));
460 }
461 Ok(SessionCompletion { text })
462 }
463
464 async fn completion_stream(
465 &self,
466 request: SessionCompletionRequest,
467 ) -> std::result::Result<SessionCompletionStream, SessionCompletionError> {
468 let prepared = self.prepare_completion(request).await?;
469 let events = prepared
470 .driver
471 .chat_completion_stream(prepared.llm_messages, &prepared.llm_config)
472 .await
473 .map_err(|error| SessionCompletionError::Completion {
474 error: error.to_string(),
475 context: prepared.context.clone(),
476 })?;
477 Ok(SessionCompletionStream {
478 events,
479 context: prepared.context,
480 })
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{Agent, AgentStatus};
    use crate::capabilities::TestMathCapability;
    use crate::harness::{Harness, HarnessStatus};
    use crate::in_memory::{
        InMemoryAgentStore, InMemoryHarnessStore, InMemoryLlmProviderStore,
        InMemoryMessageRetriever, InMemorySessionStore,
    };
    use crate::llm_driver_registry::{LlmStreamEvent, ProviderType};
    use crate::llm_models::LlmProviderType;
    use crate::llmsim_driver::{LlmSimConfig, LlmSimDriver};
    use crate::message_retriever::InputMessage;
    use crate::session::SessionStatus;
    use crate::typed_id::{AgentId, HarnessId};
    use chrono::Utc;
    use futures::StreamExt;

    // Fixed ids shared by the store-backed fixtures.
    const HARNESS_ID: &str = "harness_000000000000000000000000000000a1";
    const AGENT_ID: &str = "agent_000000000000000000000000000000a1";
    const SESSION_ID: &str = "session_000000000000000000000000000000a1";

    /// The disabled host refuses every operation with a recognisable error.
    #[tokio::test]
    async fn disabled_host_errors_clearly() {
        let host = DisabledCommandHost;

        let turn_error = host.turn_context().await.unwrap_err();
        assert!(turn_error.to_string().contains("turn-context"));

        let completion_error = host
            .completion(SessionCompletionRequest::default())
            .await
            .unwrap_err();
        assert!(matches!(
            completion_error,
            SessionCompletionError::InvalidRequest(_)
        ));

        // Streaming is not overridden, so the trait default answers.
        let stream_error = host
            .completion_stream(SessionCompletionRequest::default())
            .await
            .unwrap_err();
        assert!(matches!(
            stream_error,
            SessionCompletionError::StreamingUnsupported
        ));
        let command_error = stream_error.into_command_result().unwrap_err();
        assert!(command_error.to_string().contains("streaming"));
    }

    /// Minimal active harness that enables the `test_math` capability.
    fn test_harness(id: HarnessId) -> Harness {
        Harness {
            id,
            name: "h".into(),
            display_name: None,
            description: None,
            system_prompt: "You are a test harness.".into(),
            parent_harness_id: None,
            default_model_id: None,
            tags: vec![],
            capabilities: vec![crate::AgentCapabilityConfig::new("test_math")],
            initial_files: vec![],
            network_access: None,
            mcp_servers: Default::default(),
            is_built_in: false,
            status: HarnessStatus::Active,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            archived_at: None,
            deleted_at: None,
        }
    }

    /// Minimal active agent with its own short system prompt.
    fn test_agent(public_id: AgentId) -> Agent {
        Agent {
            public_id,
            internal_id: Uuid::nil(),
            name: "a".into(),
            display_name: None,
            description: None,
            system_prompt: "Use tools.".into(),
            default_model_id: None,
            default_version_id: None,
            forked_from_agent_id: None,
            forked_from_version_id: None,
            root_agent_id: None,
            tags: vec![],
            capabilities: vec![],
            initial_files: vec![],
            network_access: None,
            max_iterations: Some(8),
            tools: vec![],
            mcp_servers: Default::default(),
            status: AgentStatus::Active,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            archived_at: None,
            deleted_at: None,
            usage: None,
        }
    }

    /// Started session linking the given harness and agent.
    fn test_session(id: SessionId, harness_id: HarnessId, agent_id: AgentId) -> Session {
        Session {
            id,
            organization_id: crate::DEFAULT_ORG_PUBLIC_ID.to_string(),
            harness_id,
            agent_id: Some(agent_id),
            agent_version_id: None,
            agent_identity_id: None,
            owner_principal_id: crate::PrincipalId::from_seed(1),
            resolved_owner_user_id: None,
            owner: None,
            effective_owner: None,
            title: None,
            locale: None,
            preview: None,
            output_preview: None,
            tags: vec![],
            model_id: None,
            capabilities: vec![],
            tools: vec![],
            mcp_servers: Default::default(),
            system_prompt: None,
            initial_files: vec![],
            hints: None,
            network_access: None,
            max_iterations: None,
            status: SessionStatus::Started,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            started_at: None,
            finished_at: None,
            usage: None,
            is_pinned: None,
            active_schedule_count: None,
            features: vec![],
            parent_session_id: None,
            subagent_name: None,
            subagent_task: None,
            subagent_status: None,
            blueprint_id: None,
            blueprint_config: None,
        }
    }

    /// Builds a [`StoreCommandHost`] over in-memory stores whose default model
    /// is an LlmSim model that always replies with `response`.
    async fn llmsim_host(response: &str) -> StoreCommandHost {
        let harness_id: HarnessId = HARNESS_ID.parse().unwrap();
        let agent_id: AgentId = AGENT_ID.parse().unwrap();
        let session_id: SessionId = SESSION_ID.parse().unwrap();

        let harnesses = InMemoryHarnessStore::new();
        harnesses.add_harness(test_harness(harness_id)).await;

        let agents = InMemoryAgentStore::new();
        agents.add_agent(test_agent(agent_id)).await;

        let sessions = InMemorySessionStore::new();
        sessions
            .add_session(test_session(session_id, harness_id, agent_id))
            .await;

        let history = InMemoryMessageRetriever::new();
        history
            .add(session_id, InputMessage::user("earlier message"))
            .await
            .unwrap();

        let providers = InMemoryLlmProviderStore::new();
        let default_model = ModelWithProvider {
            model: "llmsim-model".into(),
            provider_type: LlmProviderType::LlmSim,
            api_key: Some("fake-key".into()),
            base_url: None,
        };
        providers.set_default_model(default_model).await;

        let mut capabilities = CapabilityRegistry::new();
        capabilities.register(TestMathCapability);

        let mut drivers = DriverRegistry::new();
        let sim = LlmSimDriver::new(LlmSimConfig::fixed(response));
        drivers.register(ProviderType::LlmSim, move |_api_key, _base_url| {
            Box::new(sim.clone())
        });

        StoreCommandHost::new(
            session_id,
            Arc::new(harnesses),
            Arc::new(agents),
            Arc::new(sessions),
            Arc::new(history),
            Arc::new(providers),
            capabilities,
            drivers,
        )
    }

    /// A non-streaming completion uses the session's model and returns its reply.
    #[tokio::test]
    async fn store_host_completion_runs_against_session_model() {
        let host = llmsim_host("the side answer").await;

        let turn = host.turn_context().await.unwrap();
        assert_eq!(turn.model, "llmsim-model");
        assert_eq!(turn.provider_type, "llmsim");
        assert_eq!(turn.messages.len(), 1);
        assert!(!turn.system_prompt.is_empty());

        let request = SessionCompletionRequest {
            system_prompts: vec![turn.system_prompt, "Answer once.".into()],
            messages: turn.messages,
            controls: None,
            metadata: HashMap::new(),
        };
        let completion = host.completion(request).await.unwrap();
        assert_eq!(completion.text, "the side answer");
    }

    /// A streaming completion yields several deltas followed by `Done`.
    #[tokio::test]
    async fn store_host_completion_stream_emits_progressive_deltas() {
        let expected = "streamed side answer with several tokens";
        let host = llmsim_host(expected).await;

        let turn = host.turn_context().await.unwrap();
        let request = SessionCompletionRequest {
            system_prompts: vec![turn.system_prompt],
            messages: turn.messages,
            controls: None,
            metadata: HashMap::new(),
        };
        let stream = host.completion_stream(request).await.unwrap();

        assert_eq!(stream.context.provider.as_deref(), Some("llmsim"));
        assert_eq!(stream.context.model_id.as_deref(), Some("llmsim-model"));

        let mut deltas = Vec::new();
        let mut done = false;
        let mut events = stream.events;
        while let Some(event) = events.next().await {
            match event.unwrap() {
                LlmStreamEvent::TextDelta(delta) => deltas.push(delta),
                LlmStreamEvent::Done(_) => done = true,
                _ => {}
            }
        }

        assert!(done, "stream must terminate with Done");
        assert!(
            deltas.len() > 1,
            "expected progressive deltas, got {deltas:?}"
        );
        assert_eq!(deltas.concat(), expected);
    }

    /// Provider failures become a classified, unsuccessful command result.
    #[test]
    fn completion_error_classifies_provider_failures() {
        let context = UserFacingErrorContext::default()
            .with_provider("openai")
            .with_model_id("gpt-5");
        let failure = SessionCompletionError::Completion {
            error: "OpenAI API error (401): unauthorized".to_string(),
            context,
        };

        let outcome = failure.into_command_result().expect("classified result");
        assert!(!outcome.success);
        assert_eq!(outcome.error_code.as_deref(), Some("provider_misconfigured"));

        let fields = outcome.error_fields.expect("error_fields populated");
        let provider = fields.get("provider").and_then(|v| v.as_str());
        let model_id = fields.get("model_id").and_then(|v| v.as_str());
        assert_eq!(provider, Some("openai"));
        assert_eq!(model_id, Some("gpt-5"));
    }

    /// Invalid requests are propagated as errors, not command results.
    #[test]
    fn completion_error_bubbles_invalid_requests() {
        let invalid = AgentLoopError::config("Model not found");
        let failure = SessionCompletionError::InvalidRequest(invalid);
        assert!(failure.into_command_result().is_err());
    }
}