1use crate::capabilities::CapabilityRegistry;
15use crate::command::CommandResult;
16use crate::driver_registry::{
17 BoxedChatDriver, DriverRegistry, LlmCallConfig, LlmCallConfigBuilder, LlmMessage,
18 LlmMessageRole, LlmResponseStream, ProviderConfig, ToolSearchConfig,
19};
20use crate::error::{AgentLoopError, Result};
21use crate::message::{Controls, Message, MessageRole, patch_dangling_tool_calls};
22use crate::message_retriever::MessageRetriever;
23use crate::runtime_context::{AssembledTurnContext, inspect_turn_context};
24use crate::session::Session;
25use crate::traits::{
26 AgentStore, HarnessStore, ImageResolver, ProviderStore, ResolvedImage, ResolvedModel,
27 SessionFileSystem, SessionStore,
28};
29use crate::typed_id::SessionId;
30use crate::user_facing_error::{UserFacingErrorContext, classify_runtime_error_message};
31use async_trait::async_trait;
32use std::collections::{HashMap, HashSet};
33use std::sync::Arc;
34use uuid::Uuid;
35
36#[derive(Debug, Clone)]
45pub struct CommandTurnContext {
46 pub session: Session,
48 pub messages: Vec<Message>,
50 pub system_prompt: String,
52 pub model: String,
54 pub provider_type: String,
56 pub resolved_locale: Option<String>,
58}
59
60#[derive(Debug, Clone, Default)]
65pub struct SessionCompletionRequest {
66 pub system_prompts: Vec<String>,
69 pub messages: Vec<Message>,
71 pub controls: Option<Controls>,
74 pub metadata: HashMap<String, String>,
77}
78
/// Result of a non-streaming session completion.
#[derive(Debug, Clone)]
pub struct SessionCompletion {
    /// Text produced by the model.
    pub text: String,
}
85
86pub struct SessionCompletionStream {
90 pub events: LlmResponseStream,
92 pub context: UserFacingErrorContext,
96}
97
98impl std::fmt::Debug for SessionCompletionStream {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("SessionCompletionStream")
101 .field("context", &self.context)
102 .finish()
103 }
104}
105
106#[derive(Debug)]
108pub enum SessionCompletionError {
109 InvalidRequest(AgentLoopError),
112 StreamingUnsupported,
116 Completion {
119 error: String,
121 context: UserFacingErrorContext,
123 },
124}
125
126impl SessionCompletionError {
127 pub fn into_command_result(self) -> Result<CommandResult> {
131 match self {
132 Self::InvalidRequest(error) => Err(error),
133 Self::StreamingUnsupported => Err(AgentLoopError::config(
134 "command host does not support streaming completions",
135 )),
136 Self::Completion { error, context } => {
137 let classified = classify_runtime_error_message(&error, &context);
138 Ok(CommandResult {
139 success: false,
140 message: classified.fallback_message(),
141 error_code: Some(classified.code.clone()),
142 error_fields: classified.error_fields(),
143 })
144 }
145 }
146 }
147}
148
149#[async_trait]
154pub trait CommandHost: Send + Sync {
155 async fn turn_context(&self) -> Result<CommandTurnContext>;
157
158 async fn completion(
161 &self,
162 request: SessionCompletionRequest,
163 ) -> std::result::Result<SessionCompletion, SessionCompletionError>;
164
165 async fn completion_stream(
173 &self,
174 _request: SessionCompletionRequest,
175 ) -> std::result::Result<SessionCompletionStream, SessionCompletionError> {
176 Err(SessionCompletionError::StreamingUnsupported)
177 }
178}
179
/// A [`CommandHost`] for contexts where commands have no session access.
///
/// Its turn-context and completion methods always return configuration errors.
#[derive(Debug, Clone, Copy, Default)]
pub struct DisabledCommandHost;
183
184#[async_trait]
185impl CommandHost for DisabledCommandHost {
186 async fn turn_context(&self) -> Result<CommandTurnContext> {
187 Err(AgentLoopError::config(
188 "command host does not provide turn-context access",
189 ))
190 }
191
192 async fn completion(
193 &self,
194 _request: SessionCompletionRequest,
195 ) -> std::result::Result<SessionCompletion, SessionCompletionError> {
196 Err(SessionCompletionError::InvalidRequest(
197 AgentLoopError::config("command host does not provide session completions"),
198 ))
199 }
200}
201
202pub struct StoreCommandHost {
210 session_id: SessionId,
211 harness_store: Arc<dyn HarnessStore>,
212 agent_store: Arc<dyn AgentStore>,
213 session_store: Arc<dyn SessionStore>,
214 message_retriever: Arc<dyn MessageRetriever>,
215 provider_store: Arc<dyn ProviderStore>,
216 capability_registry: CapabilityRegistry,
217 driver_registry: DriverRegistry,
218 image_resolver: Option<Arc<dyn ImageResolver>>,
219 file_store: Option<Arc<dyn SessionFileSystem>>,
220 assembled: tokio::sync::OnceCell<AssembledTurnContext>,
221}
222
223impl StoreCommandHost {
224 #[allow(clippy::too_many_arguments)]
225 pub fn new(
226 session_id: SessionId,
227 harness_store: Arc<dyn HarnessStore>,
228 agent_store: Arc<dyn AgentStore>,
229 session_store: Arc<dyn SessionStore>,
230 message_retriever: Arc<dyn MessageRetriever>,
231 provider_store: Arc<dyn ProviderStore>,
232 capability_registry: CapabilityRegistry,
233 driver_registry: DriverRegistry,
234 ) -> Self {
235 Self {
236 session_id,
237 harness_store,
238 agent_store,
239 session_store,
240 message_retriever,
241 provider_store,
242 capability_registry,
243 driver_registry,
244 image_resolver: None,
245 file_store: None,
246 assembled: tokio::sync::OnceCell::new(),
247 }
248 }
249
250 pub fn with_image_resolver(mut self, image_resolver: Arc<dyn ImageResolver>) -> Self {
253 self.image_resolver = Some(image_resolver);
254 self
255 }
256
257 pub fn with_file_store(mut self, file_store: Arc<dyn SessionFileSystem>) -> Self {
260 self.file_store = Some(file_store);
261 self
262 }
263
264 pub fn with_assembled_context(mut self, assembled: AssembledTurnContext) -> Self {
267 self.assembled = tokio::sync::OnceCell::new_with(Some(assembled));
268 self
269 }
270
271 async fn assembled(&self) -> Result<&AssembledTurnContext> {
272 self.assembled
273 .get_or_try_init(|| async {
274 let session = self
275 .session_store
276 .get_session(self.session_id)
277 .await?
278 .ok_or_else(|| AgentLoopError::session_not_found(self.session_id))?;
279 inspect_turn_context(
280 self.harness_store.as_ref(),
281 self.agent_store.as_ref(),
282 self.session_store.as_ref(),
283 self.message_retriever.as_ref(),
284 self.provider_store.as_ref(),
285 &self.capability_registry,
286 self.session_id,
287 session.harness_id,
288 session.agent_id,
289 &[],
290 self.file_store.clone(),
291 )
292 .await
293 })
294 .await
295 }
296
297 async fn resolve_images(&self, messages: &[Message]) -> HashMap<Uuid, ResolvedImage> {
298 let Some(resolver) = &self.image_resolver else {
299 return HashMap::new();
300 };
301 let image_ids: HashSet<Uuid> = messages
302 .iter()
303 .flat_map(LlmMessage::extract_image_file_ids)
304 .collect();
305 let mut resolved = HashMap::new();
306 for image_id in image_ids {
307 if let Ok(Some(image)) = resolver.resolve_image(image_id).await {
308 resolved.insert(image_id, image);
309 }
310 }
311 resolved
312 }
313
314 async fn resolve_completion_model(
318 &self,
319 controls: Option<&Controls>,
320 assembled: &AssembledTurnContext,
321 ) -> std::result::Result<ResolvedModel, SessionCompletionError> {
322 let requested = controls.and_then(|controls| controls.model_id);
323 match requested {
324 Some(model_id) if Some(model_id) != assembled.resolved_model_id => self
325 .provider_store
326 .get_resolved_model(model_id)
327 .await
328 .map_err(SessionCompletionError::InvalidRequest)?
329 .ok_or_else(|| {
330 SessionCompletionError::InvalidRequest(AgentLoopError::config(format!(
331 "Model not found: {model_id}"
332 )))
333 }),
334 _ => Ok(assembled.model_with_provider.clone()),
335 }
336 }
337
338 async fn prepare_completion(
342 &self,
343 request: SessionCompletionRequest,
344 ) -> std::result::Result<PreparedCompletion, SessionCompletionError> {
345 let assembled = self
346 .assembled()
347 .await
348 .map_err(SessionCompletionError::InvalidRequest)?;
349 let model = self
350 .resolve_completion_model(request.controls.as_ref(), assembled)
351 .await?;
352
353 let context = UserFacingErrorContext::default()
354 .with_provider(model.provider_type.to_string())
355 .with_model_id(model.model.clone());
356
357 let messages = patch_dangling_tool_calls(&request.messages);
358 let resolved_images = self.resolve_images(&messages).await;
359
360 let mut llm_messages: Vec<LlmMessage> = request
361 .system_prompts
362 .iter()
363 .filter(|prompt| !prompt.is_empty())
364 .map(|prompt| LlmMessage::text(LlmMessageRole::System, prompt.clone()))
365 .collect();
366 for msg in &messages {
367 let mut llm_msg = LlmMessage::from_message_with_images(msg, &resolved_images);
368 if msg.role == MessageRole::User
369 && let Some(actor) = &msg.external_actor
370 {
371 llm_msg.prepend_text_prefix(&format!("[{}] ", actor.display_label()));
372 }
373 llm_messages.push(llm_msg);
374 }
375
376 let mut llm_config_builder = LlmCallConfigBuilder::from(&assembled.runtime_agent)
377 .model(&model.model)
378 .tools(vec![])
379 .tool_search(ToolSearchConfig {
380 enabled: false,
381 threshold: usize::MAX,
382 })
383 .previous_response_id(None)
384 .with_metadata("session_id", self.session_id.to_string());
385 if let Some(effort) = request
386 .controls
387 .as_ref()
388 .and_then(|controls| controls.reasoning.as_ref())
389 .and_then(|reasoning| reasoning.effort.clone())
390 .filter(|value| !value.is_empty())
391 {
392 llm_config_builder = llm_config_builder.reasoning_effort(effort);
393 }
394 for (key, value) in &request.metadata {
395 llm_config_builder = llm_config_builder.with_metadata(key, value);
396 }
397 let llm_config = llm_config_builder.build();
398
399 let driver = self
400 .driver_registry
401 .create_chat_driver(&ProviderConfig::from(&model))
402 .map_err(|error| SessionCompletionError::Completion {
403 error: error.to_string(),
404 context: context.clone(),
405 })?;
406
407 Ok(PreparedCompletion {
408 llm_messages,
409 llm_config,
410 driver,
411 context,
412 })
413 }
414}
415
416struct PreparedCompletion {
419 llm_messages: Vec<LlmMessage>,
420 llm_config: LlmCallConfig,
421 driver: BoxedChatDriver,
422 context: UserFacingErrorContext,
423}
424
425#[async_trait]
426impl CommandHost for StoreCommandHost {
427 async fn turn_context(&self) -> Result<CommandTurnContext> {
428 let assembled = self.assembled().await?;
429 Ok(CommandTurnContext {
430 session: assembled.session.clone(),
431 messages: assembled.messages.clone(),
432 system_prompt: assembled.runtime_agent.system_prompt.clone(),
433 model: assembled.model_with_provider.model.clone(),
434 provider_type: assembled.model_with_provider.provider_type.to_string(),
435 resolved_locale: assembled.resolved_locale.clone(),
436 })
437 }
438
439 async fn completion(
440 &self,
441 request: SessionCompletionRequest,
442 ) -> std::result::Result<SessionCompletion, SessionCompletionError> {
443 let prepared = self.prepare_completion(request).await?;
444 let completion_error = |error: String| SessionCompletionError::Completion {
445 error,
446 context: prepared.context.clone(),
447 };
448
449 let response = prepared
450 .driver
451 .chat_completion(prepared.llm_messages, &prepared.llm_config)
452 .await
453 .map_err(|error| completion_error(error.to_string()))?;
454
455 let text = response.text.trim().to_string();
456 if text.is_empty() {
457 return Err(completion_error(
458 "session completion returned an empty response".to_string(),
459 ));
460 }
461 Ok(SessionCompletion { text })
462 }
463
464 async fn completion_stream(
465 &self,
466 request: SessionCompletionRequest,
467 ) -> std::result::Result<SessionCompletionStream, SessionCompletionError> {
468 let prepared = self.prepare_completion(request).await?;
469 let events = prepared
470 .driver
471 .chat_completion_stream(prepared.llm_messages, &prepared.llm_config)
472 .await
473 .map_err(|error| SessionCompletionError::Completion {
474 error: error.to_string(),
475 context: prepared.context.clone(),
476 })?;
477 Ok(SessionCompletionStream {
478 events,
479 context: prepared.context,
480 })
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{Agent, AgentStatus};
    use crate::capabilities::TestMathCapability;
    use crate::driver_registry::LlmStreamEvent;
    use crate::harness::{Harness, HarnessStatus};
    use crate::in_memory::{
        InMemoryAgentStore, InMemoryHarnessStore, InMemoryMessageRetriever, InMemoryProviderStore,
        InMemorySessionStore,
    };
    use crate::llmsim_driver::{LlmSimConfig, LlmSimDriver};
    use crate::message_retriever::InputMessage;
    use crate::provider::DriverId;
    use crate::session::SessionStatus;
    use crate::typed_id::{AgentId, HarnessId};
    use chrono::Utc;
    use futures::StreamExt;

    /// Every entry point of the disabled host fails with a recognisable error.
    #[tokio::test]
    async fn disabled_host_errors_clearly() {
        let host = DisabledCommandHost;

        let turn_error = host.turn_context().await.unwrap_err();
        assert!(turn_error.to_string().contains("turn-context"));

        let completion_error = host
            .completion(SessionCompletionRequest::default())
            .await
            .unwrap_err();
        assert!(matches!(
            completion_error,
            SessionCompletionError::InvalidRequest(_)
        ));

        // Streaming falls through to the trait's default implementation.
        let stream_error = host
            .completion_stream(SessionCompletionRequest::default())
            .await
            .unwrap_err();
        assert!(matches!(
            stream_error,
            SessionCompletionError::StreamingUnsupported
        ));
        let surfaced = stream_error.into_command_result().unwrap_err();
        assert!(surfaced.to_string().contains("streaming"));
    }

    /// Minimal active harness with a system prompt and the `test_math` capability.
    fn test_harness(harness_id: HarnessId) -> Harness {
        Harness {
            id: harness_id,
            name: "h".into(),
            display_name: None,
            description: None,
            system_prompt: Some("You are a test harness.".into()),
            parent_harness_id: None,
            default_model_id: None,
            tags: vec![],
            capabilities: vec![crate::AgentCapabilityConfig::new("test_math")],
            initial_files: vec![],
            network_access: None,
            parallel_tool_calls: None,
            mcp_servers: Default::default(),
            embedder_metadata: Default::default(),
            is_built_in: false,
            status: HarnessStatus::Active,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            archived_at: None,
            deleted_at: None,
        }
    }

    /// Minimal active agent with a short system prompt and no tools.
    fn test_agent(agent_id: AgentId) -> Agent {
        Agent {
            public_id: agent_id,
            internal_id: uuid::Uuid::nil(),
            name: "a".into(),
            display_name: None,
            description: None,
            system_prompt: "Use tools.".into(),
            default_model_id: None,
            default_version_id: None,
            forked_from_agent_id: None,
            forked_from_version_id: None,
            root_agent_id: None,
            tags: vec![],
            capabilities: vec![],
            initial_files: vec![],
            network_access: None,
            max_iterations: Some(8),
            parallel_tool_calls: None,
            tools: vec![],
            mcp_servers: Default::default(),
            status: AgentStatus::Active,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            archived_at: None,
            deleted_at: None,
            usage: None,
        }
    }

    /// Started session bound to the given harness and agent, with no overrides.
    fn test_session(session_id: SessionId, harness_id: HarnessId, agent_id: AgentId) -> Session {
        Session {
            id: session_id,
            workspace_id: crate::WorkspaceId::from_uuid(session_id.uuid()),
            organization_id: crate::DEFAULT_ORG_PUBLIC_ID.to_string(),
            harness_id,
            agent_id: Some(agent_id),
            agent_version_id: None,
            agent_identity_id: None,
            owner_principal_id: crate::PrincipalId::from_seed(1),
            resolved_owner_user_id: None,
            owner: None,
            effective_owner: None,
            title: None,
            locale: None,
            preview: None,
            output_preview: None,
            tags: vec![],
            model_id: None,
            capabilities: vec![],
            tools: vec![],
            mcp_servers: Default::default(),
            system_prompt: None,
            initial_files: vec![],
            hints: None,
            network_access: None,
            max_iterations: None,
            parallel_tool_calls: None,
            status: SessionStatus::Started,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            started_at: None,
            finished_at: None,
            usage: None,
            is_pinned: None,
            active_schedule_count: None,
            features: vec![],
            parent_session_id: None,
            blueprint_id: None,
            blueprint_config: None,
        }
    }

    /// Builds a `StoreCommandHost` on in-memory stores whose LLM driver always
    /// answers with `response`. The session holds one earlier user message.
    async fn llmsim_host(response: &str) -> StoreCommandHost {
        let harness_id: HarnessId = "harness_000000000000000000000000000000a1".parse().unwrap();
        let agent_id: AgentId = "agent_000000000000000000000000000000a1".parse().unwrap();
        let session_id: SessionId = "session_000000000000000000000000000000a1".parse().unwrap();

        let harnesses = InMemoryHarnessStore::new();
        harnesses.add_harness(test_harness(harness_id)).await;

        let agents = InMemoryAgentStore::new();
        agents.add_agent(test_agent(agent_id)).await;

        let sessions = InMemorySessionStore::new();
        sessions
            .add_session(test_session(session_id, harness_id, agent_id))
            .await;

        let history = InMemoryMessageRetriever::new();
        history
            .add(session_id, InputMessage::user("earlier message"))
            .await
            .unwrap();

        let providers = InMemoryProviderStore::new();
        let default_model = ResolvedModel {
            model: "llmsim-model".into(),
            provider_type: DriverId::LlmSim,
            api_key: Some("fake-key".into()),
            base_url: None,
            provider_metadata: None,
        };
        providers.set_default_model(default_model).await;

        let mut capabilities = CapabilityRegistry::new();
        capabilities.register(TestMathCapability);

        let mut drivers = DriverRegistry::new();
        let sim_driver = LlmSimDriver::new(LlmSimConfig::fixed(response));
        drivers.register(DriverId::LlmSim, move |_config| Box::new(sim_driver.clone()));

        StoreCommandHost::new(
            session_id,
            Arc::new(harnesses),
            Arc::new(agents),
            Arc::new(sessions),
            Arc::new(history),
            Arc::new(providers),
            capabilities,
            drivers,
        )
    }

    /// A plain completion runs on the session's model and returns the driver's text.
    #[tokio::test]
    async fn store_host_completion_runs_against_session_model() {
        let host = llmsim_host("the side answer").await;

        let CommandTurnContext {
            messages,
            system_prompt,
            model,
            provider_type,
            ..
        } = host.turn_context().await.unwrap();
        assert_eq!(model, "llmsim-model");
        assert_eq!(provider_type, "llmsim");
        assert_eq!(messages.len(), 1);
        assert!(!system_prompt.is_empty());

        let request = SessionCompletionRequest {
            system_prompts: vec![system_prompt, "Answer once.".into()],
            messages,
            ..SessionCompletionRequest::default()
        };
        let completion = host.completion(request).await.unwrap();
        assert_eq!(completion.text, "the side answer");
    }

    /// Streaming yields several text deltas that add up to the full response,
    /// followed by a `Done` event.
    #[tokio::test]
    async fn store_host_completion_stream_emits_progressive_deltas() {
        let host = llmsim_host("streamed side answer with several tokens").await;

        let turn = host.turn_context().await.unwrap();
        let request = SessionCompletionRequest {
            system_prompts: vec![turn.system_prompt],
            messages: turn.messages,
            ..SessionCompletionRequest::default()
        };
        let SessionCompletionStream {
            mut events,
            context,
        } = host.completion_stream(request).await.unwrap();

        assert_eq!(context.provider.as_deref(), Some("llmsim"));
        assert_eq!(context.model_id.as_deref(), Some("llmsim-model"));

        let mut chunks = Vec::new();
        let mut saw_done = false;
        while let Some(event) = events.next().await {
            match event.unwrap() {
                LlmStreamEvent::TextDelta(chunk) => chunks.push(chunk),
                LlmStreamEvent::Done(_) => saw_done = true,
                _ => {}
            }
        }

        assert!(saw_done, "stream must terminate with Done");
        assert!(
            chunks.len() > 1,
            "expected progressive deltas, got {chunks:?}"
        );
        assert_eq!(chunks.concat(), "streamed side answer with several tokens");
    }

    /// Provider failures become an unsuccessful result with a code and fields.
    #[test]
    fn completion_error_classifies_provider_failures() {
        let context = UserFacingErrorContext::default()
            .with_provider("openai")
            .with_model_id("gpt-5");
        let failure = SessionCompletionError::Completion {
            error: "OpenAI API error (401): unauthorized".to_string(),
            context,
        };

        let result = failure.into_command_result().expect("classified result");
        assert!(!result.success);
        assert_eq!(result.error_code.as_deref(), Some("provider_misconfigured"));

        let fields = result.error_fields.expect("error_fields populated");
        let provider = fields.get("provider").and_then(|v| v.as_str());
        assert_eq!(provider, Some("openai"));
        let model_id = fields.get("model_id").and_then(|v| v.as_str());
        assert_eq!(model_id, Some("gpt-5"));
    }

    /// Invalid requests are raised as `Err`, not turned into a command result.
    #[test]
    fn completion_error_bubbles_invalid_requests() {
        let invalid =
            SessionCompletionError::InvalidRequest(AgentLoopError::config("Model not found"));
        assert!(invalid.into_command_result().is_err());
    }
}