1use std::future::Future;
5use std::pin::Pin;
6use std::{
7 any::TypeId,
8 collections::HashMap,
9 sync::{LazyLock, Mutex},
10};
11
12use futures_core::Stream;
13use serde::{Deserialize, Serialize};
14
15use zeph_common::ToolName;
16
17pub use zeph_common::ToolDefinition;
18
19use crate::embed::owned_strs;
20use crate::error::LlmError;
21
22static SCHEMA_CACHE: LazyLock<Mutex<HashMap<TypeId, (serde_json::Value, String)>>> =
23 LazyLock::new(|| Mutex::new(HashMap::new()));
24
25pub(crate) fn cached_schema<T: schemars::JsonSchema + 'static>()
31-> Result<(serde_json::Value, String), crate::LlmError> {
32 let type_id = TypeId::of::<T>();
33 if let Ok(cache) = SCHEMA_CACHE.lock()
34 && let Some(entry) = cache.get(&type_id)
35 {
36 return Ok(entry.clone());
37 }
38 let schema = schemars::schema_for!(T);
39 let value = serde_json::to_value(&schema)
40 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
41 let pretty = serde_json::to_string_pretty(&schema)
42 .map_err(|e| crate::LlmError::StructuredParse(e.to_string()))?;
43 if let Ok(mut cache) = SCHEMA_CACHE.lock() {
44 cache.insert(type_id, (value.clone(), pretty.clone()));
45 }
46 Ok((value, pretty))
47}
48
/// Returns the unqualified name of `T`'s outermost type, e.g. `"Foo"` for `my_crate::Foo`.
///
/// Generic arguments, tuples and arrays are cut off before the last path segment is
/// taken, so `Vec<my::Foo>` yields `"Vec"` rather than the malformed `"Foo>"`.
///
/// Fallbacks:
/// - Types with no path at all (e.g. `u32`, `[u8; 4]`, `&str`) are returned verbatim.
/// - Types whose leading syntax has no usable name but which contain paths (e.g. tuples
///   such as `(u8, String)`) yield `"Output"`.
///
/// Based on [`std::any::type_name`], whose exact format is not guaranteed by the
/// standard library; the result is only suitable for display and prompts.
pub(crate) fn short_type_name<T: ?Sized>() -> &'static str {
    let full = std::any::type_name::<T>();
    // Only the leading type path matters: `::` inside `<...>`, `(...)` or `[...]` belongs
    // to nested types.
    let head_end = full
        .find(|c: char| matches!(c, '<' | '(' | '['))
        .unwrap_or(full.len());
    let last = full[..head_end].rsplit("::").next().unwrap_or("");
    if !last.is_empty() {
        last
    } else if full.contains("::") {
        "Output"
    } else {
        full
    }
}
68
/// Optional side-channel data returned next to the reply text by
/// [`LlmProvider::chat_with_extras`].
///
/// Every field is optional; [`Default`] yields "nothing extra".
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct ChatExtras {
    /// Entropy value reported for the reply, if the provider computed one.
    pub entropy: Option<f64>,
}

impl ChatExtras {
    /// Builds extras that carry only an `entropy` value.
    #[must_use]
    pub fn with_entropy(entropy: f64) -> Self {
        let entropy = Some(entropy);
        Self { entropy }
    }
}
107
/// One item yielded by a [`ChatStream`].
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// Assistant reply text (possibly only a fragment of the full reply).
    Content(String),
    /// Model reasoning ("thinking") text.
    Thinking(String),
    /// Compaction text.
    /// NOTE(review): presumably a provider-side context-compaction summary — confirm.
    Compaction(String),
    /// Tool invocations requested by the model.
    ToolUse(Vec<ToolUseRequest>),
}

/// Boxed, `Send` stream produced by [`LlmProvider::chat_stream`].
///
/// Failures after the stream has been opened arrive as `Err` items.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
130
/// A single tool call requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUseRequest {
    /// Call identifier.
    /// NOTE(review): presumably echoed back as `tool_use_id` in `MessagePart::ToolResult` — confirm.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: ToolName,
    /// Tool arguments as a JSON value.
    pub input: serde_json::Value,
}
145
/// Reasoning content attached to a [`ChatResponse::ToolUse`].
#[derive(Debug, Clone)]
pub enum ThinkingBlock {
    /// Visible reasoning text together with a provider-issued `signature`.
    /// NOTE(review): presumably the signature must be sent back unchanged — confirm.
    Thinking { thinking: String, signature: String },
    /// Reasoning withheld by the provider; `data` is an opaque payload.
    Redacted { data: String },
}

/// Marker text associated with replies cut off by the `max_tokens` limit.
///
/// NOTE(review): where this marker is emitted and matched is not visible in this
/// module — confirm at the call sites.
pub const MAX_TOKENS_TRUNCATION_MARKER: &str = "max_tokens limit reached";

/// Result of [`LlmProvider::chat_with_tools`].
#[derive(Debug, Clone)]
pub enum ChatResponse {
    /// A plain text reply.
    Text(String),
    /// The model requested one or more tool invocations.
    ToolUse {
        /// Text emitted alongside the tool calls, if any.
        text: Option<String>,
        /// The requested tool calls.
        tool_calls: Vec<ToolUseRequest>,
        /// Reasoning blocks that accompanied the response.
        thinking_blocks: Vec<ThinkingBlock>,
    },
}
184
/// Boxed, `Send` future resolving to an embedding vector.
///
/// Being a `Box<dyn ...>` without an explicit lifetime, it is `'static` and therefore
/// cannot borrow its input text.
pub type EmbedFuture = Pin<Box<dyn Future<Output = Result<Vec<f32>, LlmError>> + Send>>;

/// Shareable (`Send + Sync`) embedding callback mapping input text to an [`EmbedFuture`].
pub type EmbedFn = Box<dyn Fn(&str) -> EmbedFuture + Send + Sync>;

/// Sender half of an unbounded Tokio channel carrying `String` status messages.
///
/// NOTE(review): presumably used for human-readable progress updates — confirm at call sites.
pub type StatusTx = tokio::sync::mpsc::UnboundedSender<String>;
200
201#[must_use]
204pub fn default_debug_request_json(
205 messages: &[Message],
206 tools: &[ToolDefinition],
207) -> serde_json::Value {
208 serde_json::json!({
209 "model": serde_json::Value::Null,
210 "max_tokens": serde_json::Value::Null,
211 "messages": serde_json::to_value(messages).unwrap_or(serde_json::Value::Array(vec![])),
212 "tools": serde_json::to_value(tools).unwrap_or(serde_json::Value::Array(vec![])),
213 "temperature": serde_json::Value::Null,
214 "cache_control": serde_json::Value::Null,
215 })
216}
217
/// Optional per-request overrides for sampling parameters.
///
/// All fields are `None` by default.
/// NOTE(review): presumably `None` means "keep the provider's configured value" — confirm
/// in the provider implementations.
#[derive(Debug, Clone, Default)]
pub struct GenerationOverrides {
    /// Sampling temperature.
    pub temperature: Option<f64>,
    /// Top-p sampling value.
    pub top_p: Option<f64>,
    /// Top-k sampling cutoff.
    pub top_k: Option<usize>,
    /// Frequency penalty.
    pub frequency_penalty: Option<f64>,
    /// Presence penalty.
    pub presence_penalty: Option<f64>,
}
239
/// Author role of a [`Message`].
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
253
/// One structured segment of a [`Message`].
///
/// Serialized as an internally tagged object: the `snake_case` variant name is stored in
/// a `"kind"` field, e.g. `{"kind":"tool_result","tool_use_id":"...","content":"..."}`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum MessagePart {
    /// Plain text.
    Text { text: String },
    /// Output captured from a tool run.
    ToolOutput {
        /// Name of the tool that produced the output.
        tool_name: zeph_common::ToolName,
        /// Output text. With `compacted_at` set and an empty body, the part flattens to a
        /// `(pruned)` placeholder.
        body: String,
        /// Set once the output has been compacted; omitted from JSON when `None`.
        /// NOTE(review): presumably a Unix timestamp — confirm.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        compacted_at: Option<i64>,
    },
    /// Recalled text (presumably retrieved from memory — confirm).
    Recall { text: String },
    /// Code-context text.
    CodeContext { text: String },
    /// Summary text.
    Summary { text: String },
    /// Text carried over from another session (inferred from the name — confirm).
    CrossSession { text: String },
    /// A tool call issued by the assistant.
    ToolUse {
        /// Call id (presumably matched by `ToolResult::tool_use_id`).
        id: String,
        /// Tool name.
        name: String,
        /// Tool arguments as JSON.
        input: serde_json::Value,
    },
    /// The result of a tool call.
    ToolResult {
        /// Id of the tool call this result answers.
        tool_use_id: String,
        /// Result text.
        content: String,
        /// Whether the tool reported an error; `false` when absent from JSON.
        #[serde(default)]
        is_error: bool,
    },
    /// An attached image.
    Image(Box<ImageData>),
    /// Reasoning text with its provider signature (not included in flattened content).
    ThinkingBlock { thinking: String, signature: String },
    /// Opaque redacted reasoning payload (not included in flattened content).
    RedactedThinkingBlock { data: String },
    /// Compaction summary (not included in flattened content).
    Compaction { summary: String },
}
311
312impl MessagePart {
313 #[must_use]
316 pub fn as_plain_text(&self) -> Option<&str> {
317 match self {
318 Self::Text { text }
319 | Self::Recall { text }
320 | Self::CodeContext { text }
321 | Self::Summary { text }
322 | Self::CrossSession { text } => Some(text.as_str()),
323 _ => None,
324 }
325 }
326
327 #[must_use]
329 pub fn as_image(&self) -> Option<&ImageData> {
330 if let Self::Image(img) = self {
331 Some(img)
332 } else {
333 None
334 }
335 }
336}
337
338#[derive(Clone, Debug, Serialize, Deserialize)]
339pub struct ImageData {
344 #[serde(with = "serde_bytes_base64")]
345 pub data: Vec<u8>,
346 pub mime_type: String,
347}
348
349mod serde_bytes_base64 {
350 use base64::{Engine, engine::general_purpose::STANDARD};
351 use serde::{Deserialize, Deserializer, Serializer};
352
353 pub fn serialize<S>(bytes: &[u8], s: S) -> Result<S::Ok, S::Error>
354 where
355 S: Serializer,
356 {
357 s.serialize_str(&STANDARD.encode(bytes))
358 }
359
360 pub fn deserialize<'de, D>(d: D) -> Result<Vec<u8>, D::Error>
361 where
362 D: Deserializer<'de>,
363 {
364 let s = String::deserialize(d)?;
365 STANDARD.decode(&s).map_err(serde::de::Error::custom)
366 }
367}
368
369#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
385#[serde(rename_all = "snake_case")]
386pub enum MessageVisibility {
387 Both,
389 AgentOnly,
391 UserOnly,
393}
394
395impl MessageVisibility {
396 #[must_use]
398 pub fn is_agent_visible(self) -> bool {
399 matches!(self, MessageVisibility::Both | MessageVisibility::AgentOnly)
400 }
401
402 #[must_use]
404 pub fn is_user_visible(self) -> bool {
405 matches!(self, MessageVisibility::Both | MessageVisibility::UserOnly)
406 }
407}
408
409impl Default for MessageVisibility {
410 fn default() -> Self {
412 MessageVisibility::Both
413 }
414}
415
416impl MessageVisibility {
417 #[must_use]
419 pub fn as_db_str(self) -> &'static str {
420 match self {
421 MessageVisibility::Both => "both",
422 MessageVisibility::AgentOnly => "agent_only",
423 MessageVisibility::UserOnly => "user_only",
424 }
425 }
426
427 #[must_use]
431 pub fn from_db_str(s: &str) -> Self {
432 match s {
433 "agent_only" => MessageVisibility::AgentOnly,
434 "user_only" => MessageVisibility::UserOnly,
435 _ => MessageVisibility::Both,
436 }
437 }
438}
439
440#[derive(Clone, Debug, Serialize, Deserialize)]
445pub struct MessageMetadata {
446 pub visibility: MessageVisibility,
448 #[serde(default, skip_serializing_if = "Option::is_none")]
450 pub compacted_at: Option<i64>,
451 #[serde(default, skip_serializing_if = "Option::is_none")]
454 pub deferred_summary: Option<String>,
455 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
458 pub focus_pinned: bool,
459 #[serde(default, skip_serializing_if = "Option::is_none")]
462 pub focus_marker_id: Option<uuid::Uuid>,
463 #[serde(skip)]
466 pub db_id: Option<i64>,
467}
468
469impl Default for MessageMetadata {
470 fn default() -> Self {
471 Self {
472 visibility: MessageVisibility::Both,
473 compacted_at: None,
474 deferred_summary: None,
475 focus_pinned: false,
476 focus_marker_id: None,
477 db_id: None,
478 }
479 }
480}
481
482impl MessageMetadata {
483 #[must_use]
485 pub fn agent_only() -> Self {
486 Self {
487 visibility: MessageVisibility::AgentOnly,
488 compacted_at: None,
489 deferred_summary: None,
490 focus_pinned: false,
491 focus_marker_id: None,
492 db_id: None,
493 }
494 }
495
496 #[must_use]
498 pub fn user_only() -> Self {
499 Self {
500 visibility: MessageVisibility::UserOnly,
501 compacted_at: None,
502 deferred_summary: None,
503 focus_pinned: false,
504 focus_marker_id: None,
505 db_id: None,
506 }
507 }
508
509 #[must_use]
511 pub fn focus_pinned() -> Self {
512 Self {
513 visibility: MessageVisibility::AgentOnly,
514 compacted_at: None,
515 deferred_summary: None,
516 focus_pinned: true,
517 focus_marker_id: None,
518 db_id: None,
519 }
520 }
521}
522
/// A single chat message.
///
/// `content` is the flat text handed to the model (see `to_llm_content`). When `parts` is
/// non-empty, `content` is derived from it by `from_parts` / `rebuild_content`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Flattened text content.
    pub content: String,
    /// Structured parts; empty for plain-text ("legacy") messages and when absent from JSON.
    #[serde(default)]
    pub parts: Vec<MessagePart>,
    /// Visibility and bookkeeping metadata; defaulted when absent from JSON.
    #[serde(default)]
    pub metadata: MessageMetadata,
}
559
560impl Default for Message {
561 fn default() -> Self {
562 Self {
563 role: Role::User,
564 content: String::new(),
565 parts: vec![],
566 metadata: MessageMetadata::default(),
567 }
568 }
569}
570
571impl Message {
572 #[must_use]
577 pub fn from_legacy(role: Role, content: impl Into<String>) -> Self {
578 Self {
579 role,
580 content: content.into(),
581 parts: vec![],
582 metadata: MessageMetadata::default(),
583 }
584 }
585
586 #[must_use]
591 pub fn from_parts(role: Role, parts: Vec<MessagePart>) -> Self {
592 let content = Self::flatten_parts(&parts);
593 Self {
594 role,
595 content,
596 parts,
597 metadata: MessageMetadata::default(),
598 }
599 }
600
601 #[must_use]
604 pub fn to_llm_content(&self) -> &str {
605 &self.content
606 }
607
608 pub fn rebuild_content(&mut self) {
610 if !self.parts.is_empty() {
611 self.content = Self::flatten_parts(&self.parts);
612 }
613 }
614
615 fn flatten_parts(parts: &[MessagePart]) -> String {
616 use std::fmt::Write;
617 let mut out = String::new();
618 for part in parts {
619 match part {
620 MessagePart::Text { text }
621 | MessagePart::Recall { text }
622 | MessagePart::CodeContext { text }
623 | MessagePart::Summary { text }
624 | MessagePart::CrossSession { text } => out.push_str(text),
625 MessagePart::ToolOutput {
626 tool_name,
627 body,
628 compacted_at,
629 } => {
630 if compacted_at.is_some() {
631 if body.is_empty() {
632 let _ = write!(out, "[tool output: {tool_name}] (pruned)");
633 } else {
634 let _ = write!(out, "[tool output: {tool_name}] {body}");
635 }
636 } else {
637 let _ = write!(out, "[tool output: {tool_name}]\n```\n{body}\n```");
638 }
639 }
640 MessagePart::ToolUse { id, name, .. } => {
641 let _ = write!(out, "[tool_use: {name}({id})]");
642 }
643 MessagePart::ToolResult {
644 tool_use_id,
645 content,
646 ..
647 } => {
648 let _ = write!(out, "[tool_result: {tool_use_id}]\n{content}");
649 }
650 MessagePart::Image(img) => {
651 let _ = write!(out, "[image: {}, {} bytes]", img.mime_type, img.data.len());
652 }
653 MessagePart::ThinkingBlock { .. }
655 | MessagePart::RedactedThinkingBlock { .. }
656 | MessagePart::Compaction { .. } => {}
657 }
658 }
659 out
660 }
661}
662
663pub trait LlmProvider: Send + Sync {
717 fn context_window(&self) -> Option<usize> {
721 None
722 }
723
724 fn chat(&self, messages: &[Message]) -> impl Future<Output = Result<String, LlmError>> + Send;
730
731 fn chat_stream(
737 &self,
738 messages: &[Message],
739 ) -> impl Future<Output = Result<ChatStream, LlmError>> + Send;
740
741 fn supports_streaming(&self) -> bool;
743
744 fn embed(&self, text: &str) -> impl Future<Output = Result<Vec<f32>, LlmError>> + Send;
750
751 fn embed_batch(
761 &self,
762 texts: &[&str],
763 ) -> impl Future<Output = Result<Vec<Vec<f32>>, LlmError>> + Send {
764 let owned = owned_strs(texts);
765 async move {
766 let mut results = Vec::with_capacity(owned.len());
767 for text in &owned {
768 results.push(self.embed(text).await?);
769 }
770 Ok(results)
771 }
772 }
773
774 fn supports_embeddings(&self) -> bool;
776
777 fn name(&self) -> &str;
779
780 #[allow(clippy::unnecessary_literal_bound)]
783 fn model_identifier(&self) -> &str {
784 ""
785 }
786
787 fn supports_vision(&self) -> bool {
789 false
790 }
791
792 fn supports_tool_use(&self) -> bool {
794 true
795 }
796
797 fn chat_with_tools(
805 &self,
806 messages: &[Message],
807 _tools: &[ToolDefinition],
808 ) -> impl std::future::Future<Output = Result<ChatResponse, LlmError>> + Send {
809 let msgs = messages.to_vec();
810 async move { Ok(ChatResponse::Text(self.chat(&msgs).await?)) }
811 }
812
813 fn last_cache_usage(&self) -> Option<(u64, u64)> {
816 None
817 }
818
819 fn last_usage(&self) -> Option<(u64, u64)> {
822 None
823 }
824
825 fn take_compaction_summary(&self) -> Option<String> {
828 None
829 }
830
831 fn chat_with_extras(
846 &self,
847 messages: &[Message],
848 ) -> impl Future<Output = Result<(String, ChatExtras), LlmError>> + Send {
849 let msgs = messages.to_vec();
850 async move { Ok((self.chat(&msgs).await?, ChatExtras::default())) }
851 }
852
853 #[must_use]
857 fn debug_request_json(
858 &self,
859 messages: &[Message],
860 tools: &[ToolDefinition],
861 _stream: bool,
862 ) -> serde_json::Value {
863 default_debug_request_json(messages, tools)
864 }
865
866 fn list_models(&self) -> Vec<String> {
869 vec![]
870 }
871
872 fn supports_structured_output(&self) -> bool {
874 false
875 }
876
877 #[allow(async_fn_in_trait)]
882 async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
883 where
884 T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
885 Self: Sized,
886 {
887 let (_, schema_json) = cached_schema::<T>()?;
888 let type_name = short_type_name::<T>();
889
890 let mut augmented = messages.to_vec();
891 let instruction = format!(
892 "Respond with a valid JSON object matching this schema. \
893 Output ONLY the JSON, no markdown fences or extra text.\n\n\
894 Type: {type_name}\nSchema:\n```json\n{schema_json}\n```"
895 );
896 augmented.insert(0, Message::from_legacy(Role::System, instruction));
897
898 let raw = self.chat(&augmented).await?;
899 let cleaned = strip_json_fences(&raw);
900 match serde_json::from_str::<T>(cleaned) {
901 Ok(val) => Ok(val),
902 Err(first_err) => {
903 augmented.push(Message::from_legacy(Role::Assistant, &raw));
904 augmented.push(Message::from_legacy(
905 Role::User,
906 format!(
907 "Your response was not valid JSON. Error: {first_err}. \
908 Please output ONLY valid JSON matching the schema."
909 ),
910 ));
911 let retry_raw = self.chat(&augmented).await?;
912 let retry_cleaned = strip_json_fences(&retry_raw);
913 serde_json::from_str::<T>(retry_cleaned).map_err(|e| {
914 LlmError::StructuredParse(format!("parse failed after retry: {e}"))
915 })
916 }
917 }
918 }
919}
920
/// Removes a surrounding Markdown code fence from a model reply, if present.
///
/// Handling:
/// - Trims outer whitespace.
/// - Drops one leading ```` ``` ```` together with an optional `json` language tag in
///   any letter case (`json`, `JSON`, `Json`, ...).
/// - Drops one trailing ```` ``` ````, then trims again.
///
/// Input without fences is returned trimmed. An unterminated opening fence is still
/// stripped.
fn strip_json_fences(s: &str) -> &str {
    let s = s.trim();
    let s = match s.strip_prefix("```") {
        Some(rest) => match rest.get(..4) {
            // `get(..4)` succeeding guarantees byte 4 is a char boundary.
            Some(tag) if tag.eq_ignore_ascii_case("json") => &rest[4..],
            _ => rest,
        },
        None => s,
    };
    let s = s.strip_suffix("```").unwrap_or(s);
    s.trim()
}
931
932#[cfg(test)]
933mod tests {
934 use tokio_stream::StreamExt;
935
936 use super::*;
937
938 struct StubProvider {
939 response: String,
940 }
941
942 impl LlmProvider for StubProvider {
943 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
944 Ok(self.response.clone())
945 }
946
947 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
948 let response = self.chat(messages).await?;
949 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
950 response,
951 )))))
952 }
953
954 fn supports_streaming(&self) -> bool {
955 false
956 }
957
958 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
959 Ok(vec![0.1, 0.2, 0.3])
960 }
961
962 fn supports_embeddings(&self) -> bool {
963 false
964 }
965
966 fn name(&self) -> &'static str {
967 "stub"
968 }
969 }
970
971 #[test]
972 fn context_window_default_returns_none() {
973 let provider = StubProvider {
974 response: String::new(),
975 };
976 assert!(provider.context_window().is_none());
977 }
978
979 #[test]
980 fn supports_streaming_default_returns_false() {
981 let provider = StubProvider {
982 response: String::new(),
983 };
984 assert!(!provider.supports_streaming());
985 }
986
987 #[tokio::test]
988 async fn chat_stream_default_yields_single_chunk() {
989 let provider = StubProvider {
990 response: "hello world".into(),
991 };
992 let messages = vec![Message {
993 role: Role::User,
994 content: "test".into(),
995 parts: vec![],
996 metadata: MessageMetadata::default(),
997 }];
998
999 let mut stream = provider.chat_stream(&messages).await.unwrap();
1000 let chunk = stream.next().await.unwrap().unwrap();
1001 assert!(matches!(chunk, StreamChunk::Content(s) if s == "hello world"));
1002 assert!(stream.next().await.is_none());
1003 }
1004
1005 #[tokio::test]
1006 async fn chat_stream_default_propagates_chat_error() {
1007 struct FailProvider;
1008
1009 impl LlmProvider for FailProvider {
1010 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1011 Err(LlmError::Unavailable)
1012 }
1013
1014 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1015 let response = self.chat(messages).await?;
1016 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1017 response,
1018 )))))
1019 }
1020
1021 fn supports_streaming(&self) -> bool {
1022 false
1023 }
1024
1025 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1026 Err(LlmError::Unavailable)
1027 }
1028
1029 fn supports_embeddings(&self) -> bool {
1030 false
1031 }
1032
1033 fn name(&self) -> &'static str {
1034 "fail"
1035 }
1036 }
1037
1038 let provider = FailProvider;
1039 let messages = vec![Message {
1040 role: Role::User,
1041 content: "test".into(),
1042 parts: vec![],
1043 metadata: MessageMetadata::default(),
1044 }];
1045
1046 let result = provider.chat_stream(&messages).await;
1047 assert!(result.is_err());
1048 if let Err(e) = result {
1049 assert!(e.to_string().contains("provider unavailable"));
1050 }
1051 }
1052
1053 #[tokio::test]
1054 async fn stub_provider_embed_returns_vector() {
1055 let provider = StubProvider {
1056 response: String::new(),
1057 };
1058 let embedding = provider.embed("test").await.unwrap();
1059 assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
1060 }
1061
1062 #[tokio::test]
1063 async fn fail_provider_embed_propagates_error() {
1064 struct FailProvider;
1065
1066 impl LlmProvider for FailProvider {
1067 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1068 Err(LlmError::Unavailable)
1069 }
1070
1071 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1072 let response = self.chat(messages).await?;
1073 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1074 response,
1075 )))))
1076 }
1077
1078 fn supports_streaming(&self) -> bool {
1079 false
1080 }
1081
1082 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1083 Err(LlmError::EmbedUnsupported {
1084 provider: "fail".into(),
1085 })
1086 }
1087
1088 fn supports_embeddings(&self) -> bool {
1089 false
1090 }
1091
1092 fn name(&self) -> &'static str {
1093 "fail"
1094 }
1095 }
1096
1097 let provider = FailProvider;
1098 let result = provider.embed("test").await;
1099 assert!(result.is_err());
1100 assert!(
1101 result
1102 .unwrap_err()
1103 .to_string()
1104 .contains("embedding not supported")
1105 );
1106 }
1107
1108 #[test]
1109 fn role_serialization() {
1110 let system = Role::System;
1111 let user = Role::User;
1112 let assistant = Role::Assistant;
1113
1114 assert_eq!(serde_json::to_string(&system).unwrap(), "\"system\"");
1115 assert_eq!(serde_json::to_string(&user).unwrap(), "\"user\"");
1116 assert_eq!(serde_json::to_string(&assistant).unwrap(), "\"assistant\"");
1117 }
1118
1119 #[test]
1120 fn role_deserialization() {
1121 let system: Role = serde_json::from_str("\"system\"").unwrap();
1122 let user: Role = serde_json::from_str("\"user\"").unwrap();
1123 let assistant: Role = serde_json::from_str("\"assistant\"").unwrap();
1124
1125 assert_eq!(system, Role::System);
1126 assert_eq!(user, Role::User);
1127 assert_eq!(assistant, Role::Assistant);
1128 }
1129
1130 #[test]
1131 fn message_clone() {
1132 let msg = Message {
1133 role: Role::User,
1134 content: "test".into(),
1135 parts: vec![],
1136 metadata: MessageMetadata::default(),
1137 };
1138 let cloned = msg.clone();
1139 assert_eq!(cloned.role, msg.role);
1140 assert_eq!(cloned.content, msg.content);
1141 }
1142
1143 #[test]
1144 fn message_debug() {
1145 let msg = Message {
1146 role: Role::Assistant,
1147 content: "response".into(),
1148 parts: vec![],
1149 metadata: MessageMetadata::default(),
1150 };
1151 let debug = format!("{msg:?}");
1152 assert!(debug.contains("Assistant"));
1153 assert!(debug.contains("response"));
1154 }
1155
1156 #[test]
1157 fn message_serialization() {
1158 let msg = Message {
1159 role: Role::User,
1160 content: "hello".into(),
1161 parts: vec![],
1162 metadata: MessageMetadata::default(),
1163 };
1164 let json = serde_json::to_string(&msg).unwrap();
1165 assert!(json.contains("\"role\":\"user\""));
1166 assert!(json.contains("\"content\":\"hello\""));
1167 }
1168
1169 #[test]
1170 fn message_part_serde_round_trip() {
1171 let parts = vec![
1172 MessagePart::Text {
1173 text: "hello".into(),
1174 },
1175 MessagePart::ToolOutput {
1176 tool_name: "bash".into(),
1177 body: "output".into(),
1178 compacted_at: None,
1179 },
1180 MessagePart::Recall {
1181 text: "recall".into(),
1182 },
1183 MessagePart::CodeContext {
1184 text: "code".into(),
1185 },
1186 MessagePart::Summary {
1187 text: "summary".into(),
1188 },
1189 ];
1190 let json = serde_json::to_string(&parts).unwrap();
1191 let deserialized: Vec<MessagePart> = serde_json::from_str(&json).unwrap();
1192 assert_eq!(deserialized.len(), 5);
1193 }
1194
1195 #[test]
1196 fn from_legacy_creates_empty_parts() {
1197 let msg = Message::from_legacy(Role::User, "hello");
1198 assert_eq!(msg.role, Role::User);
1199 assert_eq!(msg.content, "hello");
1200 assert!(msg.parts.is_empty());
1201 assert_eq!(msg.to_llm_content(), "hello");
1202 }
1203
1204 #[test]
1205 fn from_parts_flattens_content() {
1206 let msg = Message::from_parts(
1207 Role::System,
1208 vec![MessagePart::Recall {
1209 text: "recalled data".into(),
1210 }],
1211 );
1212 assert_eq!(msg.content, "recalled data");
1213 assert_eq!(msg.to_llm_content(), "recalled data");
1214 assert_eq!(msg.parts.len(), 1);
1215 }
1216
1217 #[test]
1218 fn from_parts_tool_output_format() {
1219 let msg = Message::from_parts(
1220 Role::User,
1221 vec![MessagePart::ToolOutput {
1222 tool_name: "bash".into(),
1223 body: "hello world".into(),
1224 compacted_at: None,
1225 }],
1226 );
1227 assert!(msg.content.contains("[tool output: bash]"));
1228 assert!(msg.content.contains("hello world"));
1229 }
1230
1231 #[test]
1232 fn message_deserializes_without_parts() {
1233 let json = r#"{"role":"user","content":"hello"}"#;
1234 let msg: Message = serde_json::from_str(json).unwrap();
1235 assert_eq!(msg.content, "hello");
1236 assert!(msg.parts.is_empty());
1237 }
1238
1239 #[test]
1240 fn flatten_skips_compacted_tool_output_empty_body() {
1241 let msg = Message::from_parts(
1243 Role::User,
1244 vec![
1245 MessagePart::Text {
1246 text: "prefix ".into(),
1247 },
1248 MessagePart::ToolOutput {
1249 tool_name: "bash".into(),
1250 body: String::new(),
1251 compacted_at: Some(1234),
1252 },
1253 MessagePart::Text {
1254 text: " suffix".into(),
1255 },
1256 ],
1257 );
1258 assert!(msg.content.contains("(pruned)"));
1259 assert!(msg.content.contains("prefix "));
1260 assert!(msg.content.contains(" suffix"));
1261 }
1262
1263 #[test]
1264 fn flatten_compacted_tool_output_with_reference_renders_body() {
1265 let ref_notice = "[tool output pruned; full content at /tmp/overflow/big.txt]";
1267 let msg = Message::from_parts(
1268 Role::User,
1269 vec![MessagePart::ToolOutput {
1270 tool_name: "bash".into(),
1271 body: ref_notice.into(),
1272 compacted_at: Some(1234),
1273 }],
1274 );
1275 assert!(msg.content.contains(ref_notice));
1276 assert!(!msg.content.contains("(pruned)"));
1277 }
1278
1279 #[test]
1280 fn rebuild_content_syncs_after_mutation() {
1281 let mut msg = Message::from_parts(
1282 Role::User,
1283 vec![MessagePart::ToolOutput {
1284 tool_name: "bash".into(),
1285 body: "original".into(),
1286 compacted_at: None,
1287 }],
1288 );
1289 assert!(msg.content.contains("original"));
1290
1291 if let MessagePart::ToolOutput {
1292 ref mut compacted_at,
1293 ref mut body,
1294 ..
1295 } = msg.parts[0]
1296 {
1297 *compacted_at = Some(999);
1298 body.clear(); }
1300 msg.rebuild_content();
1301
1302 assert!(msg.content.contains("(pruned)"));
1303 assert!(!msg.content.contains("original"));
1304 }
1305
1306 #[test]
1307 fn message_part_tool_use_serde_round_trip() {
1308 let part = MessagePart::ToolUse {
1309 id: "toolu_123".into(),
1310 name: "bash".into(),
1311 input: serde_json::json!({"command": "ls"}),
1312 };
1313 let json = serde_json::to_string(&part).unwrap();
1314 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1315 if let MessagePart::ToolUse { id, name, input } = deserialized {
1316 assert_eq!(id, "toolu_123");
1317 assert_eq!(name, "bash");
1318 assert_eq!(input["command"], "ls");
1319 } else {
1320 panic!("expected ToolUse");
1321 }
1322 }
1323
1324 #[test]
1325 fn message_part_tool_result_serde_round_trip() {
1326 let part = MessagePart::ToolResult {
1327 tool_use_id: "toolu_123".into(),
1328 content: "file1.rs\nfile2.rs".into(),
1329 is_error: false,
1330 };
1331 let json = serde_json::to_string(&part).unwrap();
1332 let deserialized: MessagePart = serde_json::from_str(&json).unwrap();
1333 if let MessagePart::ToolResult {
1334 tool_use_id,
1335 content,
1336 is_error,
1337 } = deserialized
1338 {
1339 assert_eq!(tool_use_id, "toolu_123");
1340 assert_eq!(content, "file1.rs\nfile2.rs");
1341 assert!(!is_error);
1342 } else {
1343 panic!("expected ToolResult");
1344 }
1345 }
1346
1347 #[test]
1348 fn message_part_tool_result_is_error_default() {
1349 let json = r#"{"kind":"tool_result","tool_use_id":"id","content":"err"}"#;
1350 let part: MessagePart = serde_json::from_str(json).unwrap();
1351 if let MessagePart::ToolResult { is_error, .. } = part {
1352 assert!(!is_error);
1353 } else {
1354 panic!("expected ToolResult");
1355 }
1356 }
1357
1358 #[test]
1359 fn chat_response_construction() {
1360 let text = ChatResponse::Text("hello".into());
1361 assert!(matches!(text, ChatResponse::Text(s) if s == "hello"));
1362
1363 let tool_use = ChatResponse::ToolUse {
1364 text: Some("I'll run that".into()),
1365 tool_calls: vec![ToolUseRequest {
1366 id: "1".into(),
1367 name: "bash".into(),
1368 input: serde_json::json!({}),
1369 }],
1370 thinking_blocks: vec![],
1371 };
1372 assert!(matches!(tool_use, ChatResponse::ToolUse { .. }));
1373 }
1374
1375 #[test]
1376 fn flatten_parts_tool_use() {
1377 let msg = Message::from_parts(
1378 Role::Assistant,
1379 vec![MessagePart::ToolUse {
1380 id: "t1".into(),
1381 name: "bash".into(),
1382 input: serde_json::json!({"command": "ls"}),
1383 }],
1384 );
1385 assert!(msg.content.contains("[tool_use: bash(t1)]"));
1386 }
1387
1388 #[test]
1389 fn flatten_parts_tool_result() {
1390 let msg = Message::from_parts(
1391 Role::User,
1392 vec![MessagePart::ToolResult {
1393 tool_use_id: "t1".into(),
1394 content: "output here".into(),
1395 is_error: false,
1396 }],
1397 );
1398 assert!(msg.content.contains("[tool_result: t1]"));
1399 assert!(msg.content.contains("output here"));
1400 }
1401
1402 #[test]
1403 fn tool_definition_serde_round_trip() {
1404 let def = ToolDefinition {
1405 name: "bash".into(),
1406 description: "Execute a shell command".into(),
1407 parameters: serde_json::json!({"type": "object"}),
1408 output_schema: None,
1409 };
1410 let json = serde_json::to_string(&def).unwrap();
1411 let deserialized: ToolDefinition = serde_json::from_str(&json).unwrap();
1412 assert_eq!(deserialized.name, "bash");
1413 assert_eq!(deserialized.description, "Execute a shell command");
1414 }
1415
1416 #[tokio::test]
1417 async fn chat_with_tools_default_delegates_to_chat() {
1418 let provider = StubProvider {
1419 response: "hello".into(),
1420 };
1421 let messages = vec![Message::from_legacy(Role::User, "test")];
1422 let result = provider.chat_with_tools(&messages, &[]).await.unwrap();
1423 assert!(matches!(result, ChatResponse::Text(s) if s == "hello"));
1424 }
1425
1426 #[test]
1427 fn tool_output_compacted_at_serde_default() {
1428 let json = r#"{"kind":"tool_output","tool_name":"bash","body":"out"}"#;
1429 let part: MessagePart = serde_json::from_str(json).unwrap();
1430 if let MessagePart::ToolOutput { compacted_at, .. } = part {
1431 assert!(compacted_at.is_none());
1432 } else {
1433 panic!("expected ToolOutput");
1434 }
1435 }
1436
1437 #[test]
1440 fn strip_json_fences_plain_json() {
1441 assert_eq!(strip_json_fences(r#"{"a": 1}"#), r#"{"a": 1}"#);
1442 }
1443
1444 #[test]
1445 fn strip_json_fences_with_json_fence() {
1446 assert_eq!(strip_json_fences("```json\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1447 }
1448
1449 #[test]
1450 fn strip_json_fences_with_plain_fence() {
1451 assert_eq!(strip_json_fences("```\n{\"a\": 1}\n```"), r#"{"a": 1}"#);
1452 }
1453
1454 #[test]
1455 fn strip_json_fences_whitespace() {
1456 assert_eq!(strip_json_fences(" \n "), "");
1457 }
1458
1459 #[test]
1460 fn strip_json_fences_empty() {
1461 assert_eq!(strip_json_fences(""), "");
1462 }
1463
1464 #[test]
1465 fn strip_json_fences_outer_whitespace() {
1466 assert_eq!(
1467 strip_json_fences(" ```json\n{\"a\": 1}\n``` "),
1468 r#"{"a": 1}"#
1469 );
1470 }
1471
1472 #[test]
1473 fn strip_json_fences_only_opening_fence() {
1474 assert_eq!(strip_json_fences("```json\n{\"a\": 1}"), r#"{"a": 1}"#);
1475 }
1476
1477 #[derive(Debug, serde::Deserialize, schemars::JsonSchema, PartialEq)]
1480 struct TestOutput {
1481 value: String,
1482 }
1483
1484 struct SequentialStub {
1485 responses: std::sync::Mutex<Vec<Result<String, LlmError>>>,
1486 }
1487
1488 impl SequentialStub {
1489 fn new(responses: Vec<Result<String, LlmError>>) -> Self {
1490 Self {
1491 responses: std::sync::Mutex::new(responses),
1492 }
1493 }
1494 }
1495
1496 impl LlmProvider for SequentialStub {
1497 async fn chat(&self, _messages: &[Message]) -> Result<String, LlmError> {
1498 let mut responses = self.responses.lock().unwrap();
1499 if responses.is_empty() {
1500 return Err(LlmError::Other("no more responses".into()));
1501 }
1502 responses.remove(0)
1503 }
1504
1505 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
1506 let response = self.chat(messages).await?;
1507 Ok(Box::pin(tokio_stream::once(Ok(StreamChunk::Content(
1508 response,
1509 )))))
1510 }
1511
1512 fn supports_streaming(&self) -> bool {
1513 false
1514 }
1515
1516 async fn embed(&self, _text: &str) -> Result<Vec<f32>, LlmError> {
1517 Err(LlmError::EmbedUnsupported {
1518 provider: "sequential-stub".into(),
1519 })
1520 }
1521
1522 fn supports_embeddings(&self) -> bool {
1523 false
1524 }
1525
1526 fn name(&self) -> &'static str {
1527 "sequential-stub"
1528 }
1529 }
1530
1531 #[tokio::test]
1532 async fn chat_typed_happy_path() {
1533 let provider = StubProvider {
1534 response: r#"{"value": "hello"}"#.into(),
1535 };
1536 let messages = vec![Message::from_legacy(Role::User, "test")];
1537 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1538 assert_eq!(
1539 result,
1540 TestOutput {
1541 value: "hello".into()
1542 }
1543 );
1544 }
1545
1546 #[tokio::test]
1547 async fn chat_typed_retry_succeeds() {
1548 let provider = SequentialStub::new(vec![
1549 Ok("not valid json".into()),
1550 Ok(r#"{"value": "ok"}"#.into()),
1551 ]);
1552 let messages = vec![Message::from_legacy(Role::User, "test")];
1553 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1554 assert_eq!(result, TestOutput { value: "ok".into() });
1555 }
1556
1557 #[tokio::test]
1558 async fn chat_typed_both_fail() {
1559 let provider = SequentialStub::new(vec![Ok("bad json".into()), Ok("still bad".into())]);
1560 let messages = vec![Message::from_legacy(Role::User, "test")];
1561 let result = provider.chat_typed::<TestOutput>(&messages).await;
1562 let err = result.unwrap_err();
1563 assert!(err.to_string().contains("parse failed after retry"));
1564 }
1565
1566 #[tokio::test]
1567 async fn chat_typed_chat_error_propagates() {
1568 let provider = SequentialStub::new(vec![Err(LlmError::Unavailable)]);
1569 let messages = vec![Message::from_legacy(Role::User, "test")];
1570 let result = provider.chat_typed::<TestOutput>(&messages).await;
1571 assert!(matches!(result, Err(LlmError::Unavailable)));
1572 }
1573
1574 #[tokio::test]
1575 async fn chat_typed_strips_fences() {
1576 let provider = StubProvider {
1577 response: "```json\n{\"value\": \"fenced\"}\n```".into(),
1578 };
1579 let messages = vec![Message::from_legacy(Role::User, "test")];
1580 let result: TestOutput = provider.chat_typed(&messages).await.unwrap();
1581 assert_eq!(
1582 result,
1583 TestOutput {
1584 value: "fenced".into()
1585 }
1586 );
1587 }
1588
1589 #[test]
1590 fn supports_structured_output_default_false() {
1591 let provider = StubProvider {
1592 response: String::new(),
1593 };
1594 assert!(!provider.supports_structured_output());
1595 }
1596
1597 #[test]
1598 fn structured_parse_error_display() {
1599 let err = LlmError::StructuredParse("test error".into());
1600 assert_eq!(
1601 err.to_string(),
1602 "structured output parse failed: test error"
1603 );
1604 }
1605
1606 #[test]
1607 fn message_part_image_roundtrip_json() {
1608 let part = MessagePart::Image(Box::new(ImageData {
1609 data: vec![1, 2, 3, 4],
1610 mime_type: "image/jpeg".into(),
1611 }));
1612 let json = serde_json::to_string(&part).unwrap();
1613 let decoded: MessagePart = serde_json::from_str(&json).unwrap();
1614 match decoded {
1615 MessagePart::Image(img) => {
1616 assert_eq!(img.data, vec![1, 2, 3, 4]);
1617 assert_eq!(img.mime_type, "image/jpeg");
1618 }
1619 _ => panic!("expected Image variant"),
1620 }
1621 }
1622
1623 #[test]
1624 fn flatten_parts_includes_image_placeholder() {
1625 let msg = Message::from_parts(
1626 Role::User,
1627 vec![
1628 MessagePart::Text {
1629 text: "see this".into(),
1630 },
1631 MessagePart::Image(Box::new(ImageData {
1632 data: vec![0u8; 100],
1633 mime_type: "image/png".into(),
1634 })),
1635 ],
1636 );
1637 let content = msg.to_llm_content();
1638 assert!(content.contains("see this"));
1639 assert!(content.contains("[image: image/png"));
1640 }
1641
1642 #[test]
1643 fn supports_vision_default_false() {
1644 let provider = StubProvider {
1645 response: String::new(),
1646 };
1647 assert!(!provider.supports_vision());
1648 }
1649
1650 #[test]
1651 fn message_metadata_default_both_visible() {
1652 let m = MessageMetadata::default();
1653 assert!(m.visibility.is_agent_visible());
1654 assert!(m.visibility.is_user_visible());
1655 assert_eq!(m.visibility, MessageVisibility::Both);
1656 assert!(m.compacted_at.is_none());
1657 }
1658
1659 #[test]
1660 fn message_metadata_agent_only() {
1661 let m = MessageMetadata::agent_only();
1662 assert!(m.visibility.is_agent_visible());
1663 assert!(!m.visibility.is_user_visible());
1664 assert_eq!(m.visibility, MessageVisibility::AgentOnly);
1665 }
1666
1667 #[test]
1668 fn message_metadata_user_only() {
1669 let m = MessageMetadata::user_only();
1670 assert!(!m.visibility.is_agent_visible());
1671 assert!(m.visibility.is_user_visible());
1672 assert_eq!(m.visibility, MessageVisibility::UserOnly);
1673 }
1674
1675 #[test]
1676 fn message_metadata_serde_default() {
1677 let json = r#"{"role":"user","content":"hello"}"#;
1678 let msg: Message = serde_json::from_str(json).unwrap();
1679 assert!(msg.metadata.visibility.is_agent_visible());
1680 assert!(msg.metadata.visibility.is_user_visible());
1681 }
1682
1683 #[test]
1684 fn message_metadata_round_trip() {
1685 let msg = Message {
1686 role: Role::User,
1687 content: "test".into(),
1688 parts: vec![],
1689 metadata: MessageMetadata::agent_only(),
1690 };
1691 let json = serde_json::to_string(&msg).unwrap();
1692 let decoded: Message = serde_json::from_str(&json).unwrap();
1693 assert!(decoded.metadata.visibility.is_agent_visible());
1694 assert!(!decoded.metadata.visibility.is_user_visible());
1695 assert_eq!(decoded.metadata.visibility, MessageVisibility::AgentOnly);
1696 }
1697
1698 #[test]
1699 fn message_part_compaction_round_trip() {
1700 let part = MessagePart::Compaction {
1701 summary: "Context was summarized.".to_owned(),
1702 };
1703 let json = serde_json::to_string(&part).unwrap();
1704 let decoded: MessagePart = serde_json::from_str(&json).unwrap();
1705 assert!(
1706 matches!(decoded, MessagePart::Compaction { summary } if summary == "Context was summarized.")
1707 );
1708 }
1709
1710 #[test]
1711 fn flatten_parts_compaction_contributes_no_text() {
1712 let parts = vec![
1715 MessagePart::Text {
1716 text: "Hello".to_owned(),
1717 },
1718 MessagePart::Compaction {
1719 summary: "Summary".to_owned(),
1720 },
1721 ];
1722 let msg = Message::from_parts(Role::Assistant, parts);
1723 assert_eq!(msg.content.trim(), "Hello");
1725 }
1726
1727 #[test]
1728 fn stream_chunk_compaction_variant() {
1729 let chunk = StreamChunk::Compaction("A summary".to_owned());
1730 assert!(matches!(chunk, StreamChunk::Compaction(s) if s == "A summary"));
1731 }
1732
1733 #[test]
1734 fn short_type_name_extracts_last_segment() {
1735 struct MyOutput;
1736 assert_eq!(short_type_name::<MyOutput>(), "MyOutput");
1737 }
1738
1739 #[test]
1740 fn short_type_name_primitive_returns_full_name() {
1741 assert_eq!(short_type_name::<u32>(), "u32");
1743 assert_eq!(short_type_name::<bool>(), "bool");
1744 }
1745
1746 #[test]
1747 fn short_type_name_nested_path_returns_last() {
1748 assert_eq!(
1750 short_type_name::<std::collections::HashMap<u32, u32>>(),
1751 "HashMap<u32, u32>"
1752 );
1753 }
1754
1755 #[test]
1758 fn summary_roundtrip() {
1759 let part = MessagePart::Summary {
1760 text: "hello".to_string(),
1761 };
1762 let json = serde_json::to_string(&part).expect("serialization must not fail");
1763 assert!(
1764 json.contains("\"kind\":\"summary\""),
1765 "must use internally-tagged format, got: {json}"
1766 );
1767 assert!(
1768 !json.contains("\"Summary\""),
1769 "must not use externally-tagged format, got: {json}"
1770 );
1771 let decoded: MessagePart =
1772 serde_json::from_str(&json).expect("deserialization must not fail");
1773 match decoded {
1774 MessagePart::Summary { text } => assert_eq!(text, "hello"),
1775 other => panic!("expected MessagePart::Summary, got {other:?}"),
1776 }
1777 }
1778
1779 #[tokio::test]
1780 async fn embed_batch_default_empty_returns_empty() {
1781 let provider = StubProvider {
1782 response: String::new(),
1783 };
1784 let result = provider.embed_batch(&[]).await.unwrap();
1785 assert!(result.is_empty());
1786 }
1787
1788 #[tokio::test]
1789 async fn embed_batch_default_calls_embed_sequentially() {
1790 let provider = StubProvider {
1791 response: String::new(),
1792 };
1793 let texts = ["hello", "world", "foo"];
1794 let result = provider.embed_batch(&texts).await.unwrap();
1795 assert_eq!(result.len(), 3);
1796 for vec in &result {
1798 assert_eq!(vec, &[0.1_f32, 0.2, 0.3]);
1799 }
1800 }
1801
1802 #[test]
1803 fn message_visibility_db_roundtrip_both() {
1804 assert_eq!(MessageVisibility::Both.as_db_str(), "both");
1805 assert_eq!(
1806 MessageVisibility::from_db_str("both"),
1807 MessageVisibility::Both
1808 );
1809 }
1810
1811 #[test]
1812 fn message_visibility_db_roundtrip_agent_only() {
1813 assert_eq!(MessageVisibility::AgentOnly.as_db_str(), "agent_only");
1814 assert_eq!(
1815 MessageVisibility::from_db_str("agent_only"),
1816 MessageVisibility::AgentOnly
1817 );
1818 }
1819
1820 #[test]
1821 fn message_visibility_db_roundtrip_user_only() {
1822 assert_eq!(MessageVisibility::UserOnly.as_db_str(), "user_only");
1823 assert_eq!(
1824 MessageVisibility::from_db_str("user_only"),
1825 MessageVisibility::UserOnly
1826 );
1827 }
1828
1829 #[test]
1830 fn message_visibility_from_db_str_unknown_defaults_to_both() {
1831 assert_eq!(
1832 MessageVisibility::from_db_str("unknown_future_value"),
1833 MessageVisibility::Both
1834 );
1835 assert_eq!(MessageVisibility::from_db_str(""), MessageVisibility::Both);
1836 }
1837}