1use std::future::Future;
5use std::pin::Pin;
6use std::{
7 any::TypeId,
8 collections::HashMap,
9 sync::{LazyLock, Mutex},
10};
11
12use futures_core::Stream;
13use serde::{Deserialize, Serialize};
14
15use zeph_common::ToolName;
16
17pub use zeph_common::ToolDefinition;
18
19use crate::embed::owned_strs;
20use crate::error::LlmError;
21
/// Process-wide cache of generated JSON schemas, keyed by the [`TypeId`] of the Rust type
/// the schema was generated for.
///
/// Each entry holds the schema twice: as a `serde_json::Value` and as a pretty-printed
/// JSON string. The map is created lazily on first access and guarded by a [`Mutex`].
static SCHEMA_CACHE: LazyLock<Mutex<HashMap<TypeId, (serde_json::Value, String)>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));
24
25pub(crate) fn cached_schema<T: schemars::JsonSchema + 'static>()
31-> Result<(serde_json::Value, String), crate::LlmError> {
32 let type_id = TypeId::of::<T>();
33 if let Ok(cache) = SCHEMA_CACHE.lock()
34 && let Some(entry) = cache.get(&type_id)
35 {
36 return Ok(entry.clone());
37 }
38 let schema = schemars::schema_for!(T);
39 let value = serde_json::to_value(&schema)
40 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
41 let pretty = serde_json::to_string_pretty(&schema)
42 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
43 if let Ok(mut cache) = SCHEMA_CACHE.lock() {
44 cache.insert(type_id, (value.clone(), pretty.clone()));
45 }
46 Ok((value, pretty))
47}
48
/// Returns a short, human-readable name for `T` derived from [`std::any::type_name`].
///
/// The module path in front of the outermost type is removed (`my_crate::Foo` becomes
/// `Foo`). Generic arguments are kept when they contain no paths (`Option<u8>`); when
/// they do contain paths they are dropped entirely (`Option<alloc::string::String>`
/// becomes `Option`). Splitting the whole string on `::` would otherwise cut inside the
/// argument list and produce a malformed name such as `String>`.
///
/// Falls back to `"Output"` if no name can be extracted.
pub(crate) fn short_type_name<T: ?Sized>() -> &'static str {
    let full = std::any::type_name::<T>();

    // Only the part before the first `<` is the path of the outer type.
    let generics_start = full.find('<').unwrap_or(full.len());
    let name_start = full[..generics_start]
        .rfind("::")
        .map_or(0, |idx| idx + "::".len());

    let with_generics = &full[name_start..];
    let short = if with_generics.contains("::") {
        // The arguments carry module paths that cannot be shortened without allocating;
        // keep just the outer type name.
        &full[name_start..generics_start]
    } else {
        with_generics
    };

    if short.is_empty() { "Output" } else { short }
}
68
/// Optional extra data returned next to the reply text by
/// `LlmProvider::chat_with_extras`.
///
/// The type is `#[non_exhaustive]`; use [`Default`] or `ChatExtras::with_entropy` to
/// build one.
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct ChatExtras {
    /// Entropy value attached to the reply, when one is available.
    // NOTE(review): what the entropy is computed over is not visible here — confirm.
    pub entropy: Option<f64>,
}
86
87impl ChatExtras {
88 #[must_use]
101 pub fn with_entropy(entropy: f64) -> Self {
102 Self {
103 entropy: Some(entropy),
104 }
105 }
106}
107
/// One item produced by a [`ChatStream`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A piece of response text.
    Content(String),
    /// A piece of thinking text.
    // NOTE(review): presumably model reasoning that is kept apart from `Content` — confirm.
    Thinking(String),
    /// A compaction payload carried as text.
    // NOTE(review): exact meaning is not visible in this file — confirm with the producers.
    Compaction(String),
    /// One or more tool invocations requested by the model.
    ToolUse(Vec<ToolUseRequest>),
}
125
/// Boxed, pinned, `Send` stream whose items are either a [`StreamChunk`] or an [`LlmError`].
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
131
/// A single tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUseRequest {
    /// Identifier of this invocation.
    // NOTE(review): presumably echoed back as `MessagePart::ToolResult::tool_use_id` — confirm.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: ToolName,
    /// Arguments for the tool as an arbitrary JSON value.
    pub input: serde_json::Value,
}
146
/// A thinking block attached to a model response (see `ChatResponse::ToolUse`).
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum ThinkingBlock {
    /// Thinking text together with its accompanying signature string.
    Thinking { thinking: String, signature: String },
    /// A redacted block carried as an opaque string.
    Redacted { data: String },
}
160
/// Marker text for responses cut off by the `max_tokens` limit.
// NOTE(review): where this marker is emitted or matched is not visible in this file — confirm.
pub const MAX_TOKENS_TRUNCATION_MARKER: &str = "max_tokens limit reached";
164
/// Result of `LlmProvider::chat_with_tools`.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum ChatResponse {
    /// A plain text reply.
    Text(String),
    /// A reply that requests tool invocations.
    ToolUse {
        /// Text that accompanied the tool requests, if any.
        text: Option<String>,
        /// The requested tool invocations.
        tool_calls: Vec<ToolUseRequest>,
        /// Thinking blocks returned with the reply.
        thinking_blocks: Vec<ThinkingBlock>,
    },
}
187
/// Boxed, pinned, `Send` future that resolves to an embedding vector or an [`LlmError`].
pub type EmbedFuture = Pin<Box<dyn Future<Output = Result<Vec<f32>, LlmError>> + Send>>;
190
/// Boxed, thread-safe closure that starts an embedding computation for a piece of text and
/// returns the corresponding [`EmbedFuture`].
pub type EmbedFn = Box<dyn Fn(&str) -> EmbedFuture + Send + Sync>;
196
/// Sending half of an unbounded Tokio channel that carries status strings.
pub type StatusTx = tokio::sync::mpsc::UnboundedSender<String>;
203
204#[must_use]
207pub fn default_debug_request_json(
208 messages: &[Message],
209 tools: &[ToolDefinition],
210) -> serde_json::Value {
211 serde_json::json!({
212 "model": serde_json::Value::Null,
213 "max_tokens": serde_json::Value::Null,
214 "messages": serde_json::to_value(messages).unwrap_or(serde_json::Value::Array(vec![])),
215 "tools": serde_json::to_value(tools).unwrap_or(serde_json::Value::Array(vec![])),
216 "temperature": serde_json::Value::Null,
217 "cache_control": serde_json::Value::Null,
218 })
219}
220
/// Optional overrides for generation parameters.
///
/// Every field defaults to `None`.
// NOTE(review): presumably `None` means "use the provider's own setting" — confirm with callers.
#[derive(Debug, Clone, Default)]
pub struct GenerationOverrides {
    /// Override for the sampling temperature.
    pub temperature: Option<f64>,
    /// Override for the `top_p` parameter.
    pub top_p: Option<f64>,
    /// Override for the `top_k` parameter.
    pub top_k: Option<usize>,
    /// Override for the frequency penalty.
    pub frequency_penalty: Option<f64>,
    /// Override for the presence penalty.
    pub presence_penalty: Option<f64>,
}
242
/// Author role of a [`Message`].
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System role.
    System,
    /// User role.
    User,
    /// Assistant role.
    Assistant,
}
257
/// One structured segment of a [`Message`].
///
/// Serialized as an internally tagged enum: the variant name is stored in snake case under
/// the `kind` key (for example `{"kind": "tool_result", ...}`).
#[non_exhaustive]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum MessagePart {
    /// Plain text.
    Text { text: String },
    /// Output of a tool run.
    ToolOutput {
        /// Name of the tool that produced the output.
        tool_name: zeph_common::ToolName,
        /// The output text. May be empty or replaced once the part has been compacted.
        body: String,
        /// Set once the output has been compacted; omitted from JSON when `None`.
        // NOTE(review): looks like a timestamp, but unit/epoch are not visible here — confirm.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        compacted_at: Option<i64>,
    },
    /// Text segment of kind "recall"; flattened as plain text.
    Recall { text: String },
    /// Text segment of kind "code context"; flattened as plain text.
    CodeContext { text: String },
    /// Text segment of kind "summary"; flattened as plain text.
    Summary { text: String },
    /// Text segment of kind "cross session"; flattened as plain text.
    CrossSession { text: String },
    /// A tool invocation.
    ToolUse {
        /// Identifier of the invocation.
        id: String,
        /// Name of the invoked tool.
        name: String,
        /// Arguments of the invocation as an arbitrary JSON value.
        input: serde_json::Value,
    },
    /// The result of a tool invocation.
    ToolResult {
        /// Identifier of the invocation this result belongs to.
        tool_use_id: String,
        /// Result text.
        content: String,
        /// Whether the result is an error; defaults to `false` when absent from JSON.
        #[serde(default)]
        is_error: bool,
    },
    /// An image, boxed to keep the enum small.
    Image(Box<ImageData>),
    /// A thinking block with its signature; contributes nothing to the flattened text.
    ThinkingBlock { thinking: String, signature: String },
    /// A redacted thinking block; contributes nothing to the flattened text.
    RedactedThinkingBlock { data: String },
    /// A compaction summary; contributes nothing to the flattened text.
    Compaction { summary: String },
}
316
317impl MessagePart {
318 #[must_use]
321 pub fn as_plain_text(&self) -> Option<&str> {
322 match self {
323 Self::Text { text }
324 | Self::Recall { text }
325 | Self::CodeContext { text }
326 | Self::Summary { text }
327 | Self::CrossSession { text } => Some(text.as_str()),
328 _ => None,
329 }
330 }
331
332 #[must_use]
334 pub fn as_image(&self) -> Option<&ImageData> {
335 if let Self::Image(img) = self {
336 Some(img)
337 } else {
338 None
339 }
340 }
341}
342
/// Raw image bytes together with their MIME type.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ImageData {
    /// Image bytes; (de)serialized as a base64 string via `serde_bytes_base64`.
    #[serde(with = "serde_bytes_base64")]
    pub data: Vec<u8>,
    /// MIME type of the image, e.g. `image/jpeg`.
    pub mime_type: String,
}
353
354mod serde_bytes_base64 {
355 use base64::{Engine, engine::general_purpose::STANDARD};
356 use serde::{Deserialize, Deserializer, Serializer};
357
358 pub fn serialize<S>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error>
359 where
360 S: Serializer,
361 {
362 s.serialize_str(&STANDARD.encode(bytes))
363 }
364
365 pub fn deserialize<'de, D>(d: D) -> Result<Vec<u8>, D::Error>
366 where
367 D: Deserializer<'de>,
368 {
369 let s = String::deserialize(d)?;
370 STANDARD.decode(&s).map_err(serde::de::Error::custom)
371 }
372}
373
/// Audience of a message: the agent, the user, or both.
///
/// Serialized in snake case: `"both"`, `"agent_only"`, `"user_only"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum MessageVisibility {
    /// Visible to both the agent and the user.
    Both,
    /// Visible to the agent only.
    AgentOnly,
    /// Visible to the user only.
    UserOnly,
}
400
401impl MessageVisibility {
402 #[must_use]
404 pub fn is_agent_visible(self) -> bool {
405 matches!(self, MessageVisibility::Both | MessageVisibility::AgentOnly)
406 }
407
408 #[must_use]
410 pub fn is_user_visible(self) -> bool {
411 matches!(self, MessageVisibility::Both | MessageVisibility::UserOnly)
412 }
413}
414
415impl Default for MessageVisibility {
416 fn default() -> Self {
418 MessageVisibility::Both
419 }
420}
421
422impl MessageVisibility {
423 #[must_use]
425 pub fn as_db_str(self) -> &'static str {
426 match self {
427 MessageVisibility::Both => "both",
428 MessageVisibility::AgentOnly => "agent_only",
429 MessageVisibility::UserOnly => "user_only",
430 }
431 }
432
433 #[must_use]
437 pub fn from_db_str(s: &str) -> Self {
438 match s {
439 "agent_only" => MessageVisibility::AgentOnly,
440 "user_only" => MessageVisibility::UserOnly,
441 _ => MessageVisibility::Both,
442 }
443 }
444}
445
/// Bookkeeping data attached to a [`Message`].
///
/// Optional fields are omitted from serialized output when unset; `db_id` and `embedding`
/// are never (de)serialized.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Who can see the message.
    pub visibility: MessageVisibility,
    /// Set once the message has been compacted; omitted from JSON when `None`.
    // NOTE(review): looks like a timestamp, but unit/epoch are not visible here — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub compacted_at: Option<i64>,
    /// Summary text held for later use; omitted from JSON when `None`.
    // NOTE(review): when the deferred summary gets applied is not visible here — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub deferred_summary: Option<String>,
    /// Whether the message is focus-pinned; omitted from JSON when `false`.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub focus_pinned: bool,
    /// Identifier of the associated focus marker; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub focus_marker_id: Option<uuid::Uuid>,
    /// Database row id of the message; never serialized or deserialized.
    #[serde(skip)]
    pub db_id: Option<i64>,
    /// Context-fidelity tag of the message; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fidelity_tag: Option<zeph_common::ContextFidelity>,
    /// Embedding vector of the message; never serialized or deserialized.
    #[serde(skip)]
    pub embedding: Option<Vec<f32>>,
}
485
486impl Default for MessageMetadata {
487 fn default() -> Self {
488 Self {
489 visibility: MessageVisibility::Both,
490 compacted_at: None,
491 deferred_summary: None,
492 focus_pinned: false,
493 focus_marker_id: None,
494 db_id: None,
495 fidelity_tag: None,
496 embedding: None,
497 }
498 }
499}
500
501impl MessageMetadata {
502 #[must_use]
504 pub fn agent_only() -> Self {
505 Self {
506 visibility: MessageVisibility::AgentOnly,
507 compacted_at: None,
508 deferred_summary: None,
509 focus_pinned: false,
510 focus_marker_id: None,
511 db_id: None,
512 fidelity_tag: None,
513 embedding: None,
514 }
515 }
516
517 #[must_use]
519 pub fn user_only() -> Self {
520 Self {
521 visibility: MessageVisibility::UserOnly,
522 compacted_at: None,
523 deferred_summary: None,
524 focus_pinned: false,
525 focus_marker_id: None,
526 db_id: None,
527 fidelity_tag: None,
528 embedding: None,
529 }
530 }
531
532 #[must_use]
534 pub fn focus_pinned() -> Self {
535 Self {
536 visibility: MessageVisibility::AgentOnly,
537 compacted_at: None,
538 deferred_summary: None,
539 focus_pinned: true,
540 focus_marker_id: None,
541 db_id: None,
542 fidelity_tag: None,
543 embedding: None,
544 }
545 }
546}
547
/// A single conversation message.
///
/// `content` holds the flat text form; `parts` optionally holds the structured form from
/// which `content` can be rebuilt (see `Message::rebuild_content`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
    /// Author role of the message.
    pub role: Role,
    /// Flat text of the message.
    pub content: String,
    /// Structured parts; defaults to empty when absent from the serialized form.
    #[serde(default)]
    pub parts: Vec<MessagePart>,
    /// Bookkeeping metadata; defaults to `MessageMetadata::default()` when absent.
    #[serde(default)]
    pub metadata: MessageMetadata,
}
584
585impl Default for Message {
586 fn default() -> Self {
587 Self {
588 role: Role::User,
589 content: String::new(),
590 parts: vec![],
591 metadata: MessageMetadata::default(),
592 }
593 }
594}
595
596impl Message {
597 #[must_use]
602 pub fn from_legacy(role: Role, content: impl Into<String>) -> Self {
603 Self {
604 role,
605 content: content.into(),
606 parts: vec![],
607 metadata: MessageMetadata::default(),
608 }
609 }
610
611 #[must_use]
616 pub fn from_parts(role: Role, parts: Vec<MessagePart>) -> Self {
617 let content = Self::flatten_parts(&parts);
618 Self {
619 role,
620 content,
621 parts,
622 metadata: MessageMetadata::default(),
623 }
624 }
625
626 #[must_use]
629 pub fn to_llm_content(&self) -> &str {
630 &self.content
631 }
632
633 pub fn rebuild_content(&mut self) {
635 if !self.parts.is_empty() {
636 self.content = Self::flatten_parts(&self.parts);
637 }
638 }
639
640 fn flatten_parts(parts: &[MessagePart]) -> String {
641 use std::fmt::Write;
642 let mut out = String::new();
643 for part in parts {
644 match part {
645 MessagePart::Text { text }
646 | MessagePart::Recall { text }
647 | MessagePart::CodeContext { text }
648 | MessagePart::Summary { text }
649 | MessagePart::CrossSession { text } => out.push_str(text),
650 MessagePart::ToolOutput {
651 tool_name,
652 body,
653 compacted_at,
654 } => {
655 if compacted_at.is_some() {
656 if body.is_empty() {
657 let _ = write!(out, "[tool output: {tool_name}] (pruned)");
658 } else {
659 let _ = write!(out, "[tool output: {tool_name}] {body}");
660 }
661 } else {
662 let _ = write!(out, "[tool output: {tool_name}]\n```\n{body}\n```");
663 }
664 }
665 MessagePart::ToolUse { id, name, .. } => {
666 let _ = write!(out, "[tool_use: {name}({id})]");
667 }
668 MessagePart::ToolResult {
669 tool_use_id,
670 content,
671 ..
672 } => {
673 let _ = write!(out, "[tool_result: {tool_use_id}]\n{content}");
674 }
675 MessagePart::Image(img) => {
676 let _ = write!(out, "[image: {}, {} bytes]", img.mime_type, img.data.len());
677 }
678 MessagePart::ThinkingBlock { .. }
680 | MessagePart::RedactedThinkingBlock { .. }
681 | MessagePart::Compaction { .. } => {}
682 }
683 }
684 out
685 }
686}
687
688pub trait LlmProvider: Send + Sync {
756 fn context_window(&self) -> Option<usize> {
760 None
761 }
762
763 fn chat(&self, messages: &[Message]) -> impl Future<Output = Result<String, LlmError>> + Send;
769
770 fn chat_stream(
776 &self,
777 messages: &[Message],
778 ) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;
779
780 fn supports_streaming(&self) -> bool;
782
783 fn embed(&self, text: &str) -> impl Future<Output = Result<Vec<f32>, LlmError>> + Send;
789
790 fn embed_batch(
800 &self,
801 texts: &[&str],
802 ) -> impl Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
803 let owned = owned_strs(texts);
804 async move {
805 let mut results = Vec::with_capacity(owned.len());
806 for text in &owned {
807 results.push(self.embed(text).await?);
808 }
809 Ok(results)
810 }
811 }
812
813 fn supports_embeddings(&self) -> bool;
815
816 fn name(&self) -> &str;
818
819 #[allow(clippy::unnecessary_literal_bound)]
822 fn model_identifier(&self) -> &str {
823 ""
824 }
825
826 fn supports_vision(&self) -> bool {
828 false
829 }
830
831 fn supports_tool_use(&self) -> bool {
833 true
834 }
835
836 fn chat_with_tools(
844 &self,
845 messages: &[Message],
846 _tools: &[ToolDefinition],
847 ) -> impl std::future::Future<Output = Result<ChatResponse, LlmError>> + Send {
848 let msgs = messages.to_vec();
849 async move { Ok(ChatResponse::Text(self.chat(&msgs).await?)) }
850 }
851
852 fn last_cache_usage(&self) -> Option<(u64, u64)> {
855 None
856 }
857
858 fn last_usage(&self) -> Option<(u64, u64)> {
861 None
862 }
863
864 fn last_reasoning_tokens(&self) -> Option<u64> {
869 None
870 }
871
872 fn take_compaction_summary(&self) -> Option<String> {
875 None
876 }
877
878 fn chat_with_extras(
893 &self,
894 messages: &[Message],
895 ) -> impl Future<Output = Result<(String, ChatExtras), LlmError>> + Send {
896 let msgs = messages.to_vec();
897 async move { Ok((self.chat(&msgs).await?, ChatExtras::default())) }
898 }
899
900 #[must_use]
904 fn debug_request_json(
905 &self,
906 messages: &[Message],
907 tools: &[ToolDefinition],
908 _stream: bool,
909 ) -> serde_json::Value {
910 default_debug_request_json(messages, tools)
911 }
912
913 fn list_models(&self) -> Vec<String> {
916 vec![]
917 }
918
919 fn supports_structured_output(&self) -> bool {
921 false
922 }
923
924 #[allow(async_fn_in_trait)]
935 async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
936 where
937 T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
938 Self: Sized,
939 {
940 let (_, schema_json) = cached_schema::<T>()?;
941 let type_name = short_type_name::<T>();
942
943 let mut augmented = messages.to_vec();
944 let instruction = format!(
945 "Respond with a valid JSON object matching this schema. \
946 Output ONLY the JSON, no markdown fences or extra text.\n\n\
947 Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
948 );
949 augmented.insert(0, Message::from_legacy(Role::System, instruction));
950
951 let raw = self.chat(&augmented).await?;
952 let cleaned = strip_json_fences(&raw);
953 match serde_json::from_str::<T>(cleaned) {
954 Ok(val) => Ok(val),
955 Err(first_err) => {
956 augmented.push(Message::from_legacy(Role::Assistant, &raw));
957 augmented.push(Message::from_legacy(
958 Role::User,
959 format!(
960 "Your response was not valid JSON. Error: {first_err}. \
961 Please output ONLY valid JSON matching the schema."
962 ),
963 ));
964 let retry_raw = self.chat(&augmented).await?;
965 let retry_cleaned = strip_json_fences(&retry_raw);
966 serde_json::from_str::<T>(retry_cleaned).map_err(|e| {
967 LlmError::StructuredParse(format!("parse failed after retry: {e}"))
968 })
969 }
970 }
971 }
972}
973
/// Removes a surrounding Markdown code fence from a model reply and trims whitespace.
///
/// Handles a leading fence with or without a `json` language tag, and a trailing fence.
/// The tag is matched ASCII case-insensitively, so a fence opened with an upper-case
/// `JSON` tag is stripped as well instead of leaving the tag in front of the payload.
/// A reply with only an opening fence, or with no fence at all, is also accepted.
fn strip_json_fences(s: &str) -> &str {
    const FENCE: &str = "```";
    const TAG: &str = "json";

    let trimmed = s.trim();
    let body = match trimmed.strip_prefix(FENCE) {
        // `get` returns `None` when the range is out of bounds or splits a multi-byte
        // character, so the tag check cannot panic on arbitrary input.
        Some(after_fence) => match after_fence.get(..TAG.len()) {
            Some(tag) if tag.eq_ignore_ascii_case(TAG) => &after_fence[TAG.len()..],
            _ => after_fence,
        },
        None => trimmed,
    };
    body.trim_end_matches(FENCE).trim()
}
984
985#[cfg(test)]
986mod tests {
987 use tokio_stream::StreamExt;
988
989 use super::*;
990
991 struct StubProvider {
992 response: String,
993 }
994
995 impl LlmProvider for StubProvider {
996 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
997 Ok(self.response.clone())
998 }
999
1000 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1001 let response = self.chat(messages).await?;
1002 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1003 response,
1004 )))))
1005 }
1006
1007 fn supports_streaming(&self) -> bool {
1008 false
1009 }
1010
1011 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1012 Ok(vec![0.1, 0.2, 0.3])
1013 }
1014
1015 fn supports_embeddings(&self) -> bool {
1016 false
1017 }
1018
1019 fn name(&self) -> &'static str {
1020 "stub"
1021 }
1022 }
1023
1024 #[test]
1025 fn context_window_default_returns_none() {
1026 let provider = StubProvider {
1027 response: String::new(),
1028 };
1029 assert!(provider.context_window().is_none());
1030 }
1031
1032 #[test]
1033 fn supports_streaming_default_returns_false() {
1034 let provider = StubProvider {
1035 response: String::new(),
1036 };
1037 assert!(!provider.supports_streaming());
1038 }
1039
1040 #[tokio::test]
1041 async fn chat_stream_default_yields_single_chunk() {
1042 let provider = StubProvider {
1043 response: "hello world".into(),
1044 };
1045 let messages = vec![Message {
1046 role: Role::User,
1047 content: "test".into(),
1048 parts: vec![],
1049 metadata: MessageMetadata::default(),
1050 }];
1051
1052 let mut stream = provider.chat_stream(&messages).await.unwrap();
1053 let chunk = stream.next().await.unwrap().unwrap();
1054 assert!(matches!(chunk, StreamChunk::Content(s) if s == "hello world"));
1055 assert!(stream.next().await.is_none());
1056 }
1057
1058 #[tokio::test]
1059 async fn chat_stream_default_propagates_chat_error() {
1060 struct FailProvider;
1061
1062 impl LlmProvider for FailProvider {
1063 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1064 Err(LlmError::Unavailable)
1065 }
1066
1067 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1068 let response = self.chat(messages).await?;
1069 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1070 response,
1071 )))))
1072 }
1073
1074 fn supports_streaming(&self) -> bool {
1075 false
1076 }
1077
1078 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1079 Err(LlmError::Unavailable)
1080 }
1081
1082 fn supports_embeddings(&self) -> bool {
1083 false
1084 }
1085
1086 fn name(&self) -> &'static str {
1087 "fail"
1088 }
1089 }
1090
1091 let provider = FailProvider;
1092 let messages = vec![Message {
1093 role: Role::User,
1094 content: "test".into(),
1095 parts: vec![],
1096 metadata: MessageMetadata::default(),
1097 }];
1098
1099 let result = provider.chat_stream(&messages).await;
1100 assert!(result.is_err());
1101 if let Err(e) = result {
1102 assert!(e.to_string().contains("provider unavailable"));
1103 }
1104 }
1105
1106 #[tokio::test]
1107 async fn stub_provider_embed_returns_vector() {
1108 let provider = StubProvider {
1109 response: String::new(),
1110 };
1111 let embedding = provider.embed("test").await.unwrap();
1112 assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
1113 }
1114
1115 #[tokio::test]
1116 async fn fail_provider_embed_propagates_error() {
1117 struct FailProvider;
1118
1119 impl LlmProvider for FailProvider {
1120 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1121 Err(LlmError::Unavailable)
1122 }
1123
1124 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1125 let response = self.chat(messages).await?;
1126 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1127 response,
1128 )))))
1129 }
1130
1131 fn supports_streaming(&self) -> bool {
1132 false
1133 }
1134
1135 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1136 Err(LlmError::EmbedUnsupported {
1137 provider: "fail".into(),
1138 })
1139 }
1140
1141 fn supports_embeddings(&self) -> bool {
1142 false
1143 }
1144
1145 fn name(&self) -> &'static str {
1146 "fail"
1147 }
1148 }
1149
1150 let provider = FailProvider;
1151 let result = provider.embed("test").await;
1152 assert!(result.is_err());
1153 assert!(
1154 result
1155 .unwrap_err()
1156 .to_string()
1157 .contains("embedding not supported")
1158 );
1159 }
1160
1161 #[test]
1162 fn role_serialization() {
1163 let system = Role::System;
1164 let user = Role::User;
1165 let assistant = Role::Assistant;
1166
1167 assert_eq!(serde_json::to_string(&system).unwrap(), "\"system\"");
1168 assert_eq!(serde_json::to_string(&user).unwrap(), "\"user\"");
1169 assert_eq!(serde_json::to_string(&assistant).unwrap(), "\"assistant\"");
1170 }
1171
1172 #[test]
1173 fn role_deserialization() {
1174 let system: Role = serde_json::from_str("\"system\"").unwrap();
1175 let user: Role = serde_json::from_str("\"user\"").unwrap();
1176 let assistant: Role = serde_json::from_str("\"assistant\"").unwrap();
1177
1178 assert_eq!(system, Role::System);
1179 assert_eq!(user, Role::User);
1180 assert_eq!(assistant, Role::Assistant);
1181 }
1182
1183 #[test]
1184 fn message_clone() {
1185 let msg = Message {
1186 role: Role::User,
1187 content: "test".into(),
1188 parts: vec![],
1189 metadata: MessageMetadata::default(),
1190 };
1191 let cloned = msg.clone();
1192 assert_eq!(cloned.role, msg.role);
1193 assert_eq!(cloned.content, msg.content);
1194 }
1195
1196 #[test]
1197 fn message_debug() {
1198 let msg = Message {
1199 role: Role::Assistant,
1200 content: "response".into(),
1201 parts: vec![],
1202 metadata: MessageMetadata::default(),
1203 };
1204 let debug = format!("{msg:?}");
1205 assert!(debug.contains("Assistant"));
1206 assert!(debug.contains("response"));
1207 }
1208
1209 #[test]
1210 fn message_serialization() {
1211 let msg = Message {
1212 role: Role::User,
1213 content: "hello".into(),
1214 parts: vec![],
1215 metadata: MessageMetadata::default(),
1216 };
1217 let json = serde_json::to_string(&msg).unwrap();
1218 assert!(json.contains("\"role\":\"user\""));
1219 assert!(json.contains("\"content\":\"hello\""));
1220 }
1221
1222 #[test]
1223 fn message_part_serde_round_trip() {
1224 let parts = vec![
1225 MessagePart::Text {
1226 text: "hello".into(),
1227 },
1228 MessagePart::ToolOutput {
1229 tool_name: "bash".into(),
1230 body: "output".into(),
1231 compacted_at: None,
1232 },
1233 MessagePart::Recall {
1234 text: "recall".into(),
1235 },
1236 MessagePart::CodeContext {
1237 text: "code".into(),
1238 },
1239 MessagePart::Summary {
1240 text: "summary".into(),
1241 },
1242 ];
1243 let json = serde_json::to_string(&parts).unwrap();
1244 let deserialized: Vec<MessagePart> = serde_json::from_str(&json).unwrap();
1245 assert_eq!(deserialized.len(), 5);
1246 }
1247
1248 #[test]
1249 fn from_legacy_creates_empty_parts() {
1250 let msg = Message::from_legacy(Role::User, "hello");
1251 assert_eq!(msg.role, Role::User);
1252 assert_eq!(msg.content, "hello");
1253 assert!(msg.parts.is_empty());
1254 assert_eq!(msg.to_llm_content(), "hello");
1255 }
1256
1257 #[test]
1258 fn from_parts_flattens_content() {
1259 let msg = Message::from_parts(
1260 Role::System,
1261 vec![MessagePart::Recall {
1262 text: "recalled data".into(),
1263 }],
1264 );
1265 assert_eq!(msg.content, "recalled data");
1266 assert_eq!(msg.to_llm_content(), "recalled data");
1267 assert_eq!(msg.parts.len(), 1);
1268 }
1269
1270 #[test]
1271 fn from_parts_tool_output_format() {
1272 let msg = Message::from_parts(
1273 Role::User,
1274 vec![MessagePart::ToolOutput {
1275 tool_name: "bash".into(),
1276 body: "hello world".into(),
1277 compacted_at: None,
1278 }],
1279 );
1280 assert!(msg.content.contains("[tool output: bash]"));
1281 assert!(msg.content.contains("hello world"));
1282 }
1283
1284 #[test]
1285 fn message_deserializes_without_parts() {
1286 let json = r#"{"role":"user","content":"hello"}"#;
1287 let msg: Message = serde_json::from_str(json).unwrap();
1288 assert_eq!(msg.content, "hello");
1289 assert!(msg.parts.is_empty());
1290 }
1291
1292 #[test]
1293 fn flatten_skips_compacted_tool_output_empty_body() {
1294 let msg = Message::from_parts(
1296 Role::User,
1297 vec![
1298 MessagePart::Text {
1299 text: "prefix ".into(),
1300 },
1301 MessagePart::ToolOutput {
1302 tool_name: "bash".into(),
1303 body: String::new(),
1304 compacted_at: Some(1234),
1305 },
1306 MessagePart::Text {
1307 text: " suffix".into(),
1308 },
1309 ],
1310 );
1311 assert!(msg.content.contains("(pruned)"));
1312 assert!(msg.content.contains("prefix "));
1313 assert!(msg.content.contains(" suffix"));
1314 }
1315
1316 #[test]
1317 fn flatten_compacted_tool_output_with_reference_renders_body() {
1318 let ref_notice = "[tool output pruned; full content at /tmp/overflow/big.txt]";
1320 let msg = Message::from_parts(
1321 Role::User,
1322 vec![MessagePart::ToolOutput {
1323 tool_name: "bash".into(),
1324 body: ref_notice.into(),
1325 compacted_at: Some(1234),
1326 }],
1327 );
1328 assert!(msg.content.contains(ref_notice));
1329 assert!(!msg.content.contains("(pruned)"));
1330 }
1331
1332 #[test]
1333 fn rebuild_content_syncs_after_mutation() {
1334 let mut msg = Message::from_parts(
1335 Role::User,
1336 vec![MessagePart::ToolOutput {
1337 tool_name: "bash".into(),
1338 body: "original".into(),
1339 compacted_at: None,
1340 }],
1341 );
1342 assert!(msg.content.contains("original"));
1343
1344 if let MessagePart::ToolOutput {
1345 ref mut compacted_at,
1346 ref mut body,
1347 ..
1348 } = msg.parts[0]
1349 {
1350 *compacted_at = Some(999);
1351 body.clear(); }
1353 msg.rebuild_content();
1354
1355 assert!(msg.content.contains("(pruned)"));
1356 assert!(!msg.content.contains("original"));
1357 }
1358
1359 #[test]
1360 fn message_part_tool_use_serde_round_trip() {
1361 let part = MessagePart::ToolUse {
1362 id: "toolu_123".into(),
1363 name: "bash".into(),
1364 input: serde_json::json!({"command": "ls"}),
1365 };
1366 let json = serde_json::to_string(&part).unwrap();
1367 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1368 if let MessagePart::ToolUse { id, name, input } = deserialized {
1369 assert_eq!(id, "toolu_123");
1370 assert_eq!(name, "bash");
1371 assert_eq!(input["command"], "ls");
1372 } else {
1373 panic!("expected ToolUse");
1374 }
1375 }
1376
1377 #[test]
1378 fn message_part_tool_result_serde_round_trip() {
1379 let part = MessagePart::ToolResult {
1380 tool_use_id: "toolu_123".into(),
1381 content: "file1.rs\nfile2.rs".into(),
1382 is_error: false,
1383 };
1384 let json = serde_json::to_string(&part).unwrap();
1385 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1386 if let MessagePart::ToolResult {
1387 tool_use_id,
1388 content,
1389 is_error,
1390 } = deserialized
1391 {
1392 assert_eq!(tool_use_id, "toolu_123");
1393 assert_eq!(content, "file1.rs\nfile2.rs");
1394 assert!(!is_error);
1395 } else {
1396 panic!("expected ToolResult");
1397 }
1398 }
1399
1400 #[test]
1401 fn message_part_tool_result_is_error_default() {
1402 let json = r#"{"kind":"tool_result","tool_use_id":"id","content":"err"}"#;
1403 let part: MessagePart = serde_json::from_str(json).unwrap();
1404 if let MessagePart::ToolResult { is_error, .. } = part {
1405 assert!(!is_error);
1406 } else {
1407 panic!("expected ToolResult");
1408 }
1409 }
1410
1411 #[test]
1412 fn chat_response_construction() {
1413 let text = ChatResponse::Text("hello".into());
1414 assert!(matches!(text, ChatResponse::Text(s) if s == "hello"));
1415
1416 let tool_use = ChatResponse::ToolUse {
1417 text: Some("I'll run that".into()),
1418 tool_calls: vec![ToolUseRequest {
1419 id: "1".into(),
1420 name: "bash".into(),
1421 input: serde_json::json!({}),
1422 }],
1423 thinking_blocks: vec![],
1424 };
1425 assert!(matches!(tool_use, ChatResponse::ToolUse { .. }));
1426 }
1427
1428 #[test]
1429 fn flatten_parts_tool_use() {
1430 let msg = Message::from_parts(
1431 Role::Assistant,
1432 vec![MessagePart::ToolUse {
1433 id: "t1".into(),
1434 name: "bash".into(),
1435 input: serde_json::json!({"command": "ls"}),
1436 }],
1437 );
1438 assert!(msg.content.contains("[tool_use: bash(t1)]"));
1439 }
1440
1441 #[test]
1442 fn flatten_parts_tool_result() {
1443 let msg = Message::from_parts(
1444 Role::User,
1445 vec![MessagePart::ToolResult {
1446 tool_use_id: "t1".into(),
1447 content: "output here".into(),
1448 is_error: false,
1449 }],
1450 );
1451 assert!(msg.content.contains("[tool_result: t1]"));
1452 assert!(msg.content.contains("output here"));
1453 }
1454
1455 #[test]
1456 fn tool_definition_serde_round_trip() {
1457 let def = ToolDefinition {
1458 name: "bash".into(),
1459 description: "Execute a shell command".into(),
1460 parameters: serde_json::json!({"type": "object"}),
1461 output_schema: None,
1462 };
1463 let json = serde_json::to_string(&def).unwrap();
1464 let deserialized: ToolDefinition = serde_json::from_str(&json).unwrap();
1465 assert_eq!(deserialized.name, "bash");
1466 assert_eq!(deserialized.description, "Execute a shell command");
1467 }
1468
1469 #[tokio::test]
1470 async fn chat_with_tools_default_delegates_to_chat() {
1471 let provider = StubProvider {
1472 response: "hello".into(),
1473 };
1474 let messages = vec![Message::from_legacy(Role::User, "test")];
1475 let result = provider.chat_with_tools(&messages, &[]).await.unwrap();
1476 assert!(matches!(result, ChatResponse::Text(s) if s == "hello"));
1477 }
1478
1479 #[test]
1480 fn tool_output_compacted_at_serde_default() {
1481 let json = r#"{"kind":"tool_output","tool_name":"bash","body":"out"}"#;
1482 let part: MessagePart = serde_json::from_str(json).unwrap();
1483 if let MessagePart::ToolOutput { compacted_at, .. } = part {
1484 assert!(compacted_at.is_none());
1485 } else {
1486 panic!("expected ToolOutput");
1487 }
1488 }
1489
1490 #[test]
1493 fn strip_json_fences_plain_json() {
1494 assert_eq!(strip_json_fences(r#"{"a": 1}"#), r#"{"a": 1}"#);
1495 }
1496
1497 #[test]
1498 fn strip_json_fences_with_json_fence() {
1499 assert_eq!(strip_json_fences("```json\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1500 }
1501
1502 #[test]
1503 fn strip_json_fences_with_plain_fence() {
1504 assert_eq!(strip_json_fences("```\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1505 }
1506
1507 #[test]
1508 fn strip_json_fences_whitespace() {
1509 assert_eq!(strip_json_fences(" \n "), "");
1510 }
1511
1512 #[test]
1513 fn strip_json_fences_empty() {
1514 assert_eq!(strip_json_fences(""), "");
1515 }
1516
1517 #[test]
1518 fn strip_json_fences_outer_whitespace() {
1519 assert_eq!(
1520 strip_json_fences(" ```json\n{\"a\": 1}\n``` "),
1521 r#"{"a": 1}"#
1522 );
1523 }
1524
1525 #[test]
1526 fn strip_json_fences_only_opening_fence() {
1527 assert_eq!(strip_json_fences("```json\n{\"a\": 1}"), r#"{"a": 1}"#);
1528 }
1529
1530 #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
1533 struct TestOutput {
1534 value: String,
1535 }
1536
1537 struct SequentialStub {
1538 responses: std::sync::Mutex<Vec<Result<String, LlmError>>>,
1539 }
1540
1541 impl SequentialStub {
1542 fn new(responses: Vec<Result<String, LlmError>>) -> Self {
1543 Self {
1544 responses: std::sync::Mutex::new(responses),
1545 }
1546 }
1547 }
1548
1549 impl LlmProvider for SequentialStub {
1550 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1551 let mut responses = self.responses.lock().unwrap();
1552 if responses.is_empty() {
1553 return Err(LlmError::Other("no more responses".into()));
1554 }
1555 responses.remove(0)
1556 }
1557
1558 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1559 let response = self.chat(messages).await?;
1560 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1561 response,
1562 )))))
1563 }
1564
1565 fn supports_streaming(&self) -> bool {
1566 false
1567 }
1568
1569 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1570 Err(LlmError::EmbedUnsupported {
1571 provider: "sequential-stub".into(),
1572 })
1573 }
1574
1575 fn supports_embeddings(&self) -> bool {
1576 false
1577 }
1578
1579 fn name(&self) -> &'static str {
1580 "sequential-stub"
1581 }
1582 }
1583
1584 #[tokio::test]
1585 async fn chat_typed_happy_path() {
1586 let provider = StubProvider {
1587 response: r#"{"value": "hello"}"#.into(),
1588 };
1589 let messages = vec![Message::from_legacy(Role::User, "test")];
1590 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1591 assert_eq!(
1592 result,
1593 TestOutput {
1594 value: "hello".into()
1595 }
1596 );
1597 }
1598
1599 #[tokio::test]
1600 async fn chat_typed_retry_succeeds() {
1601 let provider = SequentialStub::new(vec![
1602 Ok("not valid json".into()),
1603 Ok(r#"{"value": "ok"}"#.into()),
1604 ]);
1605 let messages = vec![Message::from_legacy(Role::User, "test")];
1606 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1607 assert_eq!(result, TestOutput { value: "ok".into() });
1608 }
1609
1610 #[tokio::test]
1611 async fn chat_typed_both_fail() {
1612 let provider = SequentialStub::new(vec![Ok("bad json".into()), Ok("still bad".into())]);
1613 let messages = vec![Message::from_legacy(Role::User, "test")];
1614 let result = provider.chat_typed::<TestOutput>(&messages).await;
1615 let err = result.unwrap_err();
1616 assert!(err.to_string().contains("parse failed after retry"));
1617 }
1618
1619 #[tokio::test]
1620 async fn chat_typed_chat_error_propagates() {
1621 let provider = SequentialStub::new(vec![Err(LlmError::Unavailable)]);
1622 let messages = vec![Message::from_legacy(Role::User, "test")];
1623 let result = provider.chat_typed::<TestOutput>(&messages).await;
1624 assert!(matches!(result, Err(LlmError::Unavailable)));
1625 }
1626
1627 #[tokio::test]
1628 async fn chat_typed_strips_fences() {
1629 let provider = StubProvider {
1630 response: "```json\n{\"value\": \"fenced\"}\n```".into(),
1631 };
1632 let messages = vec![Message::from_legacy(Role::User, "test")];
1633 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1634 assert_eq!(
1635 result,
1636 TestOutput {
1637 value: "fenced".into()
1638 }
1639 );
1640 }
1641
1642 #[test]
1643 fn supports_structured_output_default_false() {
1644 let provider = StubProvider {
1645 response: String::new(),
1646 };
1647 assert!(!provider.supports_structured_output());
1648 }
1649
1650 #[test]
1651 fn structured_parse_error_display() {
1652 let err = LlmError::StructuredParse("test error".into());
1653 assert_eq!(
1654 err.to_string(),
1655 "structured output parse failed: test error"
1656 );
1657 }
1658
1659 #[test]
1660 fn message_part_image_roundtrip_json() {
1661 let part = MessagePart::Image(Box::new(ImageData {
1662 data: vec![1, 2, 3, 4],
1663 mime_type: "image/jpeg".into(),
1664 }));
1665 let json = serde_json::to_string(&part).unwrap();
1666 let decoded: MessagePart = serde_json::from_str(&json).unwrap();
1667 match decoded {
1668 MessagePart::Image(img) => {
1669 assert_eq!(img.data, vec![1, 2, 3, 4]);
1670 assert_eq!(img.mime_type, "image/jpeg");
1671 }
1672 _ => panic!("expected Image variant"),
1673 }
1674 }
1675
1676 #[test]
1677 fn flatten_parts_includes_image_placeholder() {
1678 let msg = Message::from_parts(
1679 Role::User,
1680 vec![
1681 MessagePart::Text {
1682 text: "see this".into(),
1683 },
1684 MessagePart::Image(Box::new(ImageData {
1685 data: vec![0u8; 100],
1686 mime_type: "image/png".into(),
1687 })),
1688 ],
1689 );
1690 let content = msg.to_llm_content();
1691 assert!(content.contains("see this"));
1692 assert!(content.contains("[image: image/png"));
1693 }
1694
1695 #[test]
1696 fn supports_vision_default_false() {
1697 let provider = StubProvider {
1698 response: String::new(),
1699 };
1700 assert!(!provider.supports_vision());
1701 }
1702
1703 #[test]
1704 fn message_metadata_default_both_visible() {
1705 let m = MessageMetadata::default();
1706 assert!(m.visibility.is_agent_visible());
1707 assert!(m.visibility.is_user_visible());
1708 assert_eq!(m.visibility, MessageVisibility::Both);
1709 assert!(m.compacted_at.is_none());
1710 }
1711
1712 #[test]
1713 fn message_metadata_agent_only() {
1714 let m = MessageMetadata::agent_only();
1715 assert!(m.visibility.is_agent_visible());
1716 assert!(!m.visibility.is_user_visible());
1717 assert_eq!(m.visibility, MessageVisibility::AgentOnly);
1718 }
1719
1720 #[test]
1721 fn message_metadata_user_only() {
1722 let m = MessageMetadata::user_only();
1723 assert!(!m.visibility.is_agent_visible());
1724 assert!(m.visibility.is_user_visible());
1725 assert_eq!(m.visibility, MessageVisibility::UserOnly);
1726 }
1727
1728 #[test]
1729 fn message_metadata_serde_default() {
1730 let json = r#"{"role":"user","content":"hello"}"#;
1731 let msg: Message = serde_json::from_str(json).unwrap();
1732 assert!(msg.metadata.visibility.is_agent_visible());
1733 assert!(msg.metadata.visibility.is_user_visible());
1734 }
1735
1736 #[test]
1737 fn message_metadata_round_trip() {
1738 let msg = Message {
1739 role: Role::User,
1740 content: "test".into(),
1741 parts: vec![],
1742 metadata: MessageMetadata::agent_only(),
1743 };
1744 let json = serde_json::to_string(&msg).unwrap();
1745 let decoded: Message = serde_json::from_str(&json).unwrap();
1746 assert!(decoded.metadata.visibility.is_agent_visible());
1747 assert!(!decoded.metadata.visibility.is_user_visible());
1748 assert_eq!(decoded.metadata.visibility, MessageVisibility::AgentOnly);
1749 }
1750
1751 #[test]
1752 fn message_part_compaction_round_trip() {
1753 let part = MessagePart::Compaction {
1754 summary: "Context was summarized.".to_owned(),
1755 };
1756 let json = serde_json::to_string(&part).unwrap();
1757 let decoded: MessagePart = serde_json::from_str(&json).unwrap();
1758 assert!(
1759 matches!(decoded, MessagePart::Compaction { summary } if summary == "Context was summarized.")
1760 );
1761 }
1762
1763 #[test]
1764 fn flatten_parts_compaction_contributes_no_text() {
1765 let parts = vec![
1768 MessagePart::Text {
1769 text: "Hello".to_owned(),
1770 },
1771 MessagePart::Compaction {
1772 summary: "Summary".to_owned(),
1773 },
1774 ];
1775 let msg = Message::from_parts(Role::Assistant, parts);
1776 assert_eq!(msg.content.trim(), "Hello");
1778 }
1779
1780 #[test]
1781 fn stream_chunk_compaction_variant() {
1782 let chunk = StreamChunk::Compaction("A summary".to_owned());
1783 assert!(matches!(chunk, StreamChunk::Compaction(s) if s == "A summary"));
1784 }
1785
1786 #[test]
1787 fn short_type_name_extracts_last_segment() {
1788 struct MyOutput;
1789 assert_eq!(short_type_name::<MyOutput>(), "MyOutput");
1790 }
1791
1792 #[test]
1793 fn short_type_name_primitive_returns_full_name() {
1794 assert_eq!(short_type_name::<u32>(), "u32");
1796 assert_eq!(short_type_name::<bool>(), "bool");
1797 }
1798
1799 #[test]
1800 fn short_type_name_nested_path_returns_last() {
1801 assert_eq!(
1803 short_type_name::<std::collections::HashMap<u32, u32>>(),
1804 "HashMap<u32, u32>"
1805 );
1806 }
1807
1808 #[test]
1811 fn summary_roundtrip() {
1812 let part = MessagePart::Summary {
1813 text: "hello".to_string(),
1814 };
1815 let json = serde_json::to_string(&part).expect("serialization must not fail");
1816 assert!(
1817 json.contains("\"kind\":\"summary\""),
1818 "must use internally-tagged format, got: {json}"
1819 );
1820 assert!(
1821 !json.contains("\"Summary\""),
1822 "must not use externally-tagged format, got: {json}"
1823 );
1824 let decoded: MessagePart =
1825 serde_json::from_str(&json).expect("deserialization must not fail");
1826 match decoded {
1827 MessagePart::Summary { text } => assert_eq!(text, "hello"),
1828 other => panic!("expected MessagePart::Summary, got {other:?}"),
1829 }
1830 }
1831
1832 #[tokio::test]
1833 async fn embed_batch_default_empty_returns_empty() {
1834 let provider = StubProvider {
1835 response: String::new(),
1836 };
1837 let result = provider.embed_batch(&[]).await.unwrap();
1838 assert!(result.is_empty());
1839 }
1840
1841 #[tokio::test]
1842 async fn embed_batch_default_calls_embed_sequentially() {
1843 let provider = StubProvider {
1844 response: String::new(),
1845 };
1846 let texts = ["hello", "world", "foo"];
1847 let result = provider.embed_batch(&texts).await.unwrap();
1848 assert_eq!(result.len(), 3);
1849 for vec in &result {
1851 assert_eq!(vec, &[0.1_f32, 0.2, 0.3]);
1852 }
1853 }
1854
1855 #[test]
1856 fn message_visibility_db_roundtrip_both() {
1857 assert_eq!(MessageVisibility::Both.as_db_str(), "both");
1858 assert_eq!(
1859 MessageVisibility::from_db_str("both"),
1860 MessageVisibility::Both
1861 );
1862 }
1863
1864 #[test]
1865 fn message_visibility_db_roundtrip_agent_only() {
1866 assert_eq!(MessageVisibility::AgentOnly.as_db_str(), "agent_only");
1867 assert_eq!(
1868 MessageVisibility::from_db_str("agent_only"),
1869 MessageVisibility::AgentOnly
1870 );
1871 }
1872
1873 #[test]
1874 fn message_visibility_db_roundtrip_user_only() {
1875 assert_eq!(MessageVisibility::UserOnly.as_db_str(), "user_only");
1876 assert_eq!(
1877 MessageVisibility::from_db_str("user_only"),
1878 MessageVisibility::UserOnly
1879 );
1880 }
1881
1882 #[test]
1883 fn message_visibility_from_db_str_unknown_defaults_to_both() {
1884 assert_eq!(
1885 MessageVisibility::from_db_str("unknown_future_value"),
1886 MessageVisibility::Both
1887 );
1888 assert_eq!(MessageVisibility::from_db_str(""), MessageVisibility::Both);
1889 }
1890}