1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
10#[serde(tag = "role", rename_all = "lowercase")]
11pub enum Message {
12 Human(HumanMessage),
14 Ai(AiMessage),
16 System(SystemMessage),
18 Tool(ToolMessage),
20}
21
22impl Message {
24 pub fn human(content: impl Into<String>) -> Self {
26 Self::Human(HumanMessage {
27 content: content.into(),
28 parts: Vec::new(),
29 })
30 }
31
32 pub fn human_with_parts(
36 content: impl Into<String>,
37 parts: Vec<crate::content::ContentPart>,
38 ) -> Self {
39 Self::Human(HumanMessage {
40 content: content.into(),
41 parts,
42 })
43 }
44
45 pub fn ai(content: impl Into<String>) -> Self {
47 Self::Ai(AiMessage {
48 content: content.into(),
49 tool_calls: Vec::new(),
50 parts: Vec::new(),
51 })
52 }
53
54 pub fn ai_with_parts(
56 content: impl Into<String>,
57 parts: Vec<crate::content::ContentPart>,
58 ) -> Self {
59 Self::Ai(AiMessage {
60 content: content.into(),
61 tool_calls: Vec::new(),
62 parts,
63 })
64 }
65
66 pub fn system(content: impl Into<String>) -> Self {
68 Self::System(SystemMessage {
69 content: content.into(),
70 })
71 }
72
73 pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
75 Self::Tool(ToolMessage {
76 tool_call_id: call_id.into(),
77 content: content.into(),
78 })
79 }
80
81 pub fn content(&self) -> &str {
84 match self {
85 Self::Human(m) => &m.content,
86 Self::Ai(m) => &m.content,
87 Self::System(m) => &m.content,
88 Self::Tool(m) => &m.content,
89 }
90 }
91
92 pub fn tool_calls(&self) -> &[ToolCall] {
94 match self {
95 Self::Ai(m) => &m.tool_calls,
96 _ => &[],
97 }
98 }
99
100 pub fn has_tool_calls(&self) -> bool {
102 matches!(self, Self::Ai(m) if !m.tool_calls.is_empty())
103 }
104
105 pub fn parts(&self) -> &[crate::content::ContentPart] {
108 match self {
109 Self::Human(m) => &m.parts,
110 Self::Ai(m) => &m.parts,
111 _ => &[],
112 }
113 }
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
118pub struct HumanMessage {
119 pub content: String,
121 #[serde(default, skip_serializing_if = "Vec::is_empty")]
124 pub parts: Vec<crate::content::ContentPart>,
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
129pub struct AiMessage {
130 pub content: String,
132 #[serde(default, skip_serializing_if = "Vec::is_empty")]
134 pub tool_calls: Vec<ToolCall>,
135 #[serde(default, skip_serializing_if = "Vec::is_empty")]
137 pub parts: Vec<crate::content::ContentPart>,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
142pub struct SystemMessage {
143 pub content: String,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
149pub struct ToolMessage {
150 pub tool_call_id: String,
152 pub content: String,
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
158pub struct ToolCall {
159 pub id: String,
161 pub name: String,
163 pub arguments: serde_json::Value,
165}
166
167impl From<String> for Message {
168 fn from(s: String) -> Self {
169 Self::human(s)
170 }
171}
172
173impl From<&str> for Message {
174 fn from(s: &str) -> Self {
175 Self::human(s)
176 }
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
188#[serde(tag = "role", rename_all = "lowercase")]
189pub enum MessageChunk {
190 Human(HumanChunk),
192 Ai(AiChunk),
194 System(SystemChunk),
196 Tool(ToolChunk),
198}
199
200#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
202pub struct HumanChunk {
203 pub content: String,
205}
206
207#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
215pub struct AiChunk {
216 pub content: String,
218 #[serde(default, skip_serializing_if = "Vec::is_empty")]
222 pub tool_calls: Vec<ToolCallChunk>,
223 #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
226 pub extras: serde_json::Map<String, serde_json::Value>,
227}
228
229#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
231pub struct SystemChunk {
232 pub content: String,
234}
235
236#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
238pub struct ToolChunk {
239 pub tool_call_id: String,
241 pub content: String,
243}
244
245#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
247pub struct ToolCallChunk {
248 pub index: usize,
251 pub id: String,
254 pub name: String,
256 pub arguments: String,
258 #[serde(default, skip_serializing_if = "serde_json::Map::is_empty")]
261 pub extras: serde_json::Map<String, serde_json::Value>,
262}
263
264impl MessageChunk {
265 pub fn content(&self) -> &str {
267 match self {
268 Self::Human(c) => &c.content,
269 Self::Ai(c) => &c.content,
270 Self::System(c) => &c.content,
271 Self::Tool(c) => &c.content,
272 }
273 }
274
275 pub fn extend(&mut self, other: MessageChunk) -> crate::Result<()> {
278 match (self, other) {
279 (Self::Human(a), Self::Human(b)) => {
280 a.content.push_str(&b.content);
281 Ok(())
282 }
283 (Self::System(a), Self::System(b)) => {
284 a.content.push_str(&b.content);
285 Ok(())
286 }
287 (Self::Tool(a), Self::Tool(b)) => {
288 if a.tool_call_id.is_empty() {
289 a.tool_call_id = b.tool_call_id;
290 }
291 a.content.push_str(&b.content);
292 Ok(())
293 }
294 (Self::Ai(a), Self::Ai(b)) => {
295 a.content.push_str(&b.content);
296 for tc in b.tool_calls {
297 match a.tool_calls.iter_mut().find(|x| x.index == tc.index) {
298 Some(existing) => {
299 if existing.id.is_empty() {
300 existing.id = tc.id;
301 }
302 if existing.name.is_empty() {
303 existing.name = tc.name;
304 }
305 existing.arguments.push_str(&tc.arguments);
306 for (k, v) in tc.extras {
307 existing.extras.insert(k, v);
308 }
309 }
310 None => a.tool_calls.push(tc),
311 }
312 }
313 for (k, v) in b.extras {
314 a.extras.insert(k, v);
315 }
316 Ok(())
317 }
318 _ => Err(crate::CognisError::Internal(
319 "cannot merge MessageChunks of different roles".into(),
320 )),
321 }
322 }
323}
324
325pub fn message_from_chunks<I: IntoIterator<Item = MessageChunk>>(
329 chunks: I,
330) -> crate::Result<Message> {
331 let mut iter = chunks.into_iter();
332 let mut acc = match iter.next() {
333 Some(c) => c,
334 None => {
335 return Err(crate::CognisError::Internal(
336 "message_from_chunks: empty chunk stream".into(),
337 ))
338 }
339 };
340 for next in iter {
341 acc.extend(next)?;
342 }
343 Ok(match acc {
344 MessageChunk::Human(c) => Message::Human(HumanMessage {
345 content: c.content,
346 parts: Vec::new(),
347 }),
348 MessageChunk::System(c) => Message::System(SystemMessage { content: c.content }),
349 MessageChunk::Tool(c) => Message::Tool(ToolMessage {
350 tool_call_id: c.tool_call_id,
351 content: c.content,
352 }),
353 MessageChunk::Ai(c) => {
354 let tool_calls = c
355 .tool_calls
356 .into_iter()
357 .map(|tc| {
358 let arguments = if tc.arguments.is_empty() {
359 serde_json::Value::Null
360 } else {
361 serde_json::from_str(&tc.arguments).map_err(|e| {
362 crate::CognisError::Serialization(format!(
363 "tool call `{}` arguments: {e}",
364 tc.name
365 ))
366 })?
367 };
368 Ok(ToolCall {
369 id: tc.id,
370 name: tc.name,
371 arguments,
372 })
373 })
374 .collect::<crate::Result<Vec<_>>>()?;
375 Message::Ai(AiMessage {
376 content: c.content,
377 tool_calls,
378 parts: Vec::new(),
379 })
380 }
381 })
382}
383
384#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
393pub struct RemoveMessage {
394 pub id: String,
397}
398
399impl RemoveMessage {
400 pub const ALL: &'static str = "__all__";
402
403 pub fn new(id: impl Into<String>) -> Self {
405 Self { id: id.into() }
406 }
407
408 pub fn all() -> Self {
410 Self {
411 id: Self::ALL.to_string(),
412 }
413 }
414
415 pub fn is_all(&self) -> bool {
417 self.id == Self::ALL
418 }
419}
420
/// Which end of the history [`trim_messages`] discards first when over budget.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum TrimStrategy {
    /// Discard the oldest messages first, keeping the most recent ones that fit.
    First,
    /// Discard the newest messages first, keeping the earliest ones that fit.
    Last,
}
437
438pub fn trim_messages<T: crate::tokenizer::Tokenizer + ?Sized>(
444 messages: &[Message],
445 max_tokens: usize,
446 tokenizer: &T,
447 strategy: TrimStrategy,
448) -> Vec<Message> {
449 if messages.is_empty() {
450 return Vec::new();
451 }
452 let pinned = matches!(messages.first(), Some(Message::System(_))) as usize;
453 let pinned_msgs: Vec<Message> = messages[..pinned].to_vec();
454 let pinned_cost: usize = pinned_msgs
455 .iter()
456 .map(|m| tokenizer.count(m.content()))
457 .sum();
458 let budget = max_tokens.saturating_sub(pinned_cost);
459
460 let candidates: &[Message] = &messages[pinned..];
461 let costs: Vec<usize> = candidates
462 .iter()
463 .map(|m| tokenizer.count(m.content()))
464 .collect();
465
466 let order: Vec<usize> = match strategy {
467 TrimStrategy::First => (0..candidates.len()).rev().collect(),
469 TrimStrategy::Last => (0..candidates.len()).collect(),
471 };
472
473 let mut keep = vec![false; candidates.len()];
474 let mut running = 0usize;
475 for idx in order {
476 let cost = costs[idx];
477 if running + cost > budget {
478 break;
479 }
480 running += cost;
481 keep[idx] = true;
482 }
483
484 let mut out = pinned_msgs;
485 out.extend(candidates.iter().zip(keep.iter()).filter_map(|(m, &k)| {
486 if k {
487 Some(m.clone())
488 } else {
489 None
490 }
491 }));
492 out
493}
494
495pub fn trim_messages_custom<F>(
505 messages: &[Message],
506 tokenizer: &dyn crate::tokenizer::Tokenizer,
507 mut keep: F,
508) -> Vec<Message>
509where
510 F: FnMut(&Message, usize, usize) -> bool,
511{
512 let mut out = Vec::with_capacity(messages.len());
513 let mut running = 0usize;
514 for (i, m) in messages.iter().enumerate() {
515 let cost = tokenizer.count(m.content());
516 if keep(m, running, i) {
517 running += cost;
518 out.push(m.clone());
519 }
520 }
521 out
522}
523
524pub fn merge_message_runs(messages: &[Message]) -> Vec<Message> {
528 let mut out: Vec<Message> = Vec::with_capacity(messages.len());
529 for msg in messages {
530 let same_role = match (out.last(), msg) {
531 (Some(Message::Human(_)), Message::Human(_)) => true,
532 (Some(Message::Ai(_)), Message::Ai(_)) => true,
533 (Some(Message::System(_)), Message::System(_)) => true,
534 (Some(Message::Tool(a)), Message::Tool(b)) => a.tool_call_id == b.tool_call_id,
537 _ => false,
538 };
539 if !same_role {
540 out.push(msg.clone());
541 continue;
542 }
543 let last = out.last_mut().expect("checked non-empty above");
544 match (last, msg) {
545 (Message::Human(a), Message::Human(b)) => {
546 if !a.content.is_empty() && !b.content.is_empty() {
547 a.content.push_str("\n\n");
548 }
549 a.content.push_str(&b.content);
550 a.parts.extend(b.parts.iter().cloned());
551 }
552 (Message::Ai(a), Message::Ai(b)) => {
553 if !a.content.is_empty() && !b.content.is_empty() {
554 a.content.push_str("\n\n");
555 }
556 a.content.push_str(&b.content);
557 a.tool_calls.extend(b.tool_calls.iter().cloned());
558 a.parts.extend(b.parts.iter().cloned());
559 }
560 (Message::System(a), Message::System(b)) => {
561 if !a.content.is_empty() && !b.content.is_empty() {
562 a.content.push_str("\n\n");
563 }
564 a.content.push_str(&b.content);
565 }
566 (Message::Tool(a), Message::Tool(b)) => {
567 if !a.content.is_empty() && !b.content.is_empty() {
568 a.content.push_str("\n\n");
569 }
570 a.content.push_str(&b.content);
571 }
572 _ => unreachable!(),
573 }
574 }
575 out
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::CharTokenizer;

    #[test]
    fn convenience_constructors() {
        assert_eq!(Message::human("hi").content(), "hi");
        assert_eq!(Message::ai("hello").content(), "hello");
        assert_eq!(Message::system("be terse").content(), "be terse");
        let t = Message::tool("call_1", "result");
        assert_eq!(t.content(), "result");
        // A plain `if let` would silently pass if the constructor built the wrong variant.
        match t {
            Message::Tool(tm) => assert_eq!(tm.tool_call_id, "call_1"),
            other => panic!("expected Message::Tool, got {other:?}"),
        }
    }

    #[test]
    fn from_impls_build_human_messages() {
        assert_eq!(Message::from("hi"), Message::human("hi"));
        assert_eq!(Message::from(String::from("hi")), Message::human("hi"));
        assert!(matches!(Message::from("hi"), Message::Human(_)));
    }

    #[test]
    fn with_parts_constructors_match_plain_ones_when_parts_empty() {
        assert_eq!(Message::human_with_parts("x", Vec::new()), Message::human("x"));
        assert_eq!(Message::ai_with_parts("x", Vec::new()), Message::ai("x"));
        assert!(Message::human("x").parts().is_empty());
        assert!(Message::system("x").parts().is_empty());
        assert!(Message::tool("c", "x").parts().is_empty());
    }

    #[test]
    fn tool_calls_accessor() {
        let m = Message::ai("none here");
        assert!(m.tool_calls().is_empty());
        assert!(!m.has_tool_calls());

        let m = Message::Ai(AiMessage {
            content: String::new(),
            tool_calls: vec![ToolCall {
                id: "c".into(),
                name: "search".into(),
                arguments: serde_json::json!({"q": "rust"}),
            }],
            parts: Vec::new(),
        });
        assert_eq!(m.tool_calls().len(), 1);
        assert!(m.has_tool_calls());

        assert!(Message::human("x").tool_calls().is_empty());
        assert!(!Message::tool("c", "r").has_tool_calls());
    }

    #[test]
    fn roundtrip_serde() {
        let samples = vec![
            Message::human("hi"),
            Message::ai("yo"),
            Message::system("s"),
            Message::tool("c1", "r"),
        ];
        for m in samples {
            let s = serde_json::to_string(&m).unwrap();
            let back: Message = serde_json::from_str(&s).unwrap();
            assert_eq!(m, back);
        }
        let s = serde_json::to_string(&Message::human("hi")).unwrap();
        assert!(s.contains("\"role\":\"human\""));
    }

    #[test]
    fn serde_omits_and_defaults_empty_collections() {
        let s = serde_json::to_string(&Message::ai("x")).unwrap();
        assert!(s.contains("\"role\":\"ai\""));
        assert!(!s.contains("tool_calls"));
        assert!(!s.contains("parts"));
        let back: Message = serde_json::from_str(r#"{"role":"ai","content":"x"}"#).unwrap();
        assert_eq!(back, Message::ai("x"));
    }

    #[test]
    fn message_chunks_merge_text() {
        let mut a = MessageChunk::Ai(AiChunk {
            content: "Hel".into(),
            ..Default::default()
        });
        a.extend(MessageChunk::Ai(AiChunk {
            content: "lo".into(),
            ..Default::default()
        }))
        .unwrap();
        assert_eq!(a.content(), "Hello");
    }

    #[test]
    fn message_chunks_merge_tool_call_arguments() {
        let mut a = MessageChunk::Ai(AiChunk {
            tool_calls: vec![ToolCallChunk {
                index: 0,
                id: "c1".into(),
                name: "search".into(),
                arguments: "{\"q\":\"ru".into(),
                ..Default::default()
            }],
            ..Default::default()
        });
        a.extend(MessageChunk::Ai(AiChunk {
            tool_calls: vec![ToolCallChunk {
                index: 0,
                arguments: "st\"}".into(),
                ..Default::default()
            }],
            ..Default::default()
        }))
        .unwrap();
        let final_msg = message_from_chunks(std::iter::once(a)).unwrap();
        let calls = final_msg.tool_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search");
        assert_eq!(calls[0].arguments["q"], "rust");
    }

    #[test]
    fn tool_call_chunks_with_distinct_indexes_stay_separate() {
        let msg = message_from_chunks(vec![
            MessageChunk::Ai(AiChunk {
                tool_calls: vec![ToolCallChunk {
                    index: 0,
                    id: "a".into(),
                    name: "one".into(),
                    arguments: "{}".into(),
                    ..Default::default()
                }],
                ..Default::default()
            }),
            MessageChunk::Ai(AiChunk {
                tool_calls: vec![ToolCallChunk {
                    index: 1,
                    id: "b".into(),
                    name: "two".into(),
                    arguments: "[1]".into(),
                    ..Default::default()
                }],
                ..Default::default()
            }),
        ])
        .unwrap();
        let calls = msg.tool_calls();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].id, "a");
        assert_eq!(calls[1].name, "two");
        assert_eq!(calls[1].arguments, serde_json::json!([1]));
    }

    #[test]
    fn empty_tool_call_arguments_become_null() {
        let msg = message_from_chunks(std::iter::once(MessageChunk::Ai(AiChunk {
            tool_calls: vec![ToolCallChunk {
                index: 0,
                id: "c".into(),
                name: "noop".into(),
                ..Default::default()
            }],
            ..Default::default()
        })))
        .unwrap();
        assert_eq!(msg.tool_calls()[0].arguments, serde_json::Value::Null);
    }

    #[test]
    fn invalid_tool_call_arguments_are_serialization_errors() {
        let err = message_from_chunks(std::iter::once(MessageChunk::Ai(AiChunk {
            tool_calls: vec![ToolCallChunk {
                index: 0,
                name: "search".into(),
                arguments: "{\"q\":".into(),
                ..Default::default()
            }],
            ..Default::default()
        })))
        .unwrap_err();
        assert!(matches!(err, crate::CognisError::Serialization(_)));
    }

    #[test]
    fn tool_chunks_keep_first_non_empty_call_id() {
        let msg = message_from_chunks(vec![
            MessageChunk::Tool(ToolChunk {
                tool_call_id: String::new(),
                content: "ab".into(),
            }),
            MessageChunk::Tool(ToolChunk {
                tool_call_id: "c1".into(),
                content: "cd".into(),
            }),
            MessageChunk::Tool(ToolChunk {
                tool_call_id: "c2".into(),
                content: String::new(),
            }),
        ])
        .unwrap();
        assert_eq!(msg, Message::tool("c1", "abcd"));
    }

    #[test]
    fn message_chunks_reject_role_mix() {
        let mut a = MessageChunk::Ai(AiChunk::default());
        let err = a
            .extend(MessageChunk::Human(HumanChunk {
                content: "x".into(),
            }))
            .unwrap_err();
        assert!(matches!(err, crate::CognisError::Internal(_)));

        let err = message_from_chunks(vec![
            MessageChunk::Ai(AiChunk::default()),
            MessageChunk::System(SystemChunk::default()),
        ])
        .unwrap_err();
        assert!(matches!(err, crate::CognisError::Internal(_)));
    }

    #[test]
    fn message_from_chunks_empty_errors() {
        let err = message_from_chunks(std::iter::empty::<MessageChunk>()).unwrap_err();
        assert!(matches!(err, crate::CognisError::Internal(_)));
    }

    #[test]
    fn message_chunks_merge_extras_map() {
        let mut a = MessageChunk::Ai(AiChunk {
            content: "x".into(),
            extras: serde_json::Map::from_iter([(
                "finish_reason".to_string(),
                serde_json::Value::String("stop".into()),
            )]),
            ..Default::default()
        });
        a.extend(MessageChunk::Ai(AiChunk {
            content: "y".into(),
            extras: serde_json::Map::from_iter([(
                "logprobs".to_string(),
                serde_json::json!([{"token": "x"}]),
            )]),
            ..Default::default()
        }))
        .unwrap();
        if let MessageChunk::Ai(ref ai) = a {
            assert_eq!(ai.extras.get("finish_reason").unwrap(), "stop");
            assert!(ai.extras.contains_key("logprobs"));
        } else {
            panic!("expected Ai");
        }
    }

    #[test]
    fn message_chunks_extras_last_writer_wins() {
        let mut a = MessageChunk::Ai(AiChunk {
            extras: serde_json::Map::from_iter([(
                "finish_reason".to_string(),
                serde_json::Value::Null,
            )]),
            ..Default::default()
        });
        a.extend(MessageChunk::Ai(AiChunk {
            extras: serde_json::Map::from_iter([(
                "finish_reason".to_string(),
                serde_json::json!("stop"),
            )]),
            ..Default::default()
        }))
        .unwrap();
        match a {
            MessageChunk::Ai(ai) => assert_eq!(ai.extras["finish_reason"], "stop"),
            other => panic!("expected Ai, got {other:?}"),
        }
    }

    #[test]
    fn remove_message_constructors() {
        let r = RemoveMessage::new("m1");
        assert_eq!(r.id, "m1");
        assert!(!r.is_all());
        assert!(RemoveMessage::all().is_all());
        assert_eq!(RemoveMessage::all().id, RemoveMessage::ALL);
    }

    #[test]
    fn trim_messages_drops_oldest_first() {
        let tok = CharTokenizer;
        let msgs = vec![
            Message::system("sys"),
            Message::human("aaaaa"),
            Message::ai("bbbbb"),
            Message::human("ccccc"),
        ];
        let out = trim_messages(&msgs, 13, &tok, TrimStrategy::First);
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].content(), "sys");
        assert_eq!(out[1].content(), "bbbbb");
        assert_eq!(out[2].content(), "ccccc");
    }

    #[test]
    fn trim_messages_drops_newest_first() {
        let tok = CharTokenizer;
        let msgs = vec![
            Message::human("aaaaa"),
            Message::human("bbbbb"),
            Message::human("ccccc"),
        ];
        let out = trim_messages(&msgs, 10, &tok, TrimStrategy::Last);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].content(), "aaaaa");
        assert_eq!(out[1].content(), "bbbbb");
    }

    #[test]
    fn trim_messages_returns_empty_when_budget_too_small_and_no_system() {
        let tok = CharTokenizer;
        let msgs = vec![Message::human("longtext")];
        let out = trim_messages(&msgs, 3, &tok, TrimStrategy::First);
        assert!(out.is_empty());
    }

    #[test]
    fn trim_messages_empty_input_returns_empty() {
        let tok = CharTokenizer;
        assert!(trim_messages(&[], 10, &tok, TrimStrategy::Last).is_empty());
    }

    #[test]
    fn trim_messages_keeps_pinned_system_even_when_over_budget() {
        let tok = CharTokenizer;
        let msgs = vec![Message::system("long system prompt"), Message::human("a")];
        let out = trim_messages(&msgs, 3, &tok, TrimStrategy::First);
        assert_eq!(out, vec![Message::system("long system prompt")]);
    }

    #[test]
    fn trim_messages_stops_at_first_message_that_does_not_fit() {
        let tok = CharTokenizer;
        let msgs = vec![
            Message::human("aaaaa"),
            Message::human("bbbbbbbb"),
            Message::human("c"),
        ];
        // "c" would fit on its own, but selection stops at "bbbbbbbb".
        let out = trim_messages(&msgs, 6, &tok, TrimStrategy::Last);
        assert_eq!(out, vec![Message::human("aaaaa")]);
    }

    #[test]
    fn trim_messages_custom_uses_predicate() {
        let tok = CharTokenizer;
        let msgs = vec![
            Message::human("aaa"),
            Message::human("bbbbbbbb"),
            Message::human("c"),
        ];
        let out = trim_messages_custom(&msgs, &tok, |m, _running, _i| {
            m.content().starts_with('a') || m.content().starts_with('c')
        });
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].content(), "aaa");
        assert_eq!(out[1].content(), "c");
    }

    #[test]
    fn trim_messages_custom_reports_running_total_of_kept_messages() {
        let tok = CharTokenizer;
        let msgs = vec![
            Message::human("aaa"),
            Message::human("bbbbbbbb"),
            Message::human("c"),
        ];
        let mut seen = Vec::new();
        let out = trim_messages_custom(&msgs, &tok, |m, running, i| {
            seen.push((running, i));
            !m.content().starts_with('b')
        });
        assert_eq!(out.len(), 2);
        assert_eq!(seen, vec![(0, 0), (3, 1), (3, 2)]);
    }

    #[test]
    fn merge_message_runs_collapses_consecutive_same_role() {
        let msgs = vec![
            Message::system("sys"),
            Message::human("a"),
            Message::human("b"),
            Message::ai("c"),
            Message::human("d"),
            Message::human("e"),
        ];
        let out = merge_message_runs(&msgs);
        assert_eq!(out.len(), 4);
        assert_eq!(out[1].content(), "a\n\nb");
        assert_eq!(out[3].content(), "d\n\ne");
    }

    #[test]
    fn merge_message_runs_does_not_merge_tool_with_different_ids() {
        let msgs = vec![Message::tool("c1", "first"), Message::tool("c2", "second")];
        let out = merge_message_runs(&msgs);
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn merge_message_runs_merges_tool_with_same_id() {
        let out = merge_message_runs(&[Message::tool("c1", "first"), Message::tool("c1", "second")]);
        assert_eq!(out, vec![Message::tool("c1", "first\n\nsecond")]);
    }

    #[test]
    fn merge_message_runs_concatenates_ai_tool_calls() {
        let call = |id: &str| ToolCall {
            id: id.into(),
            name: "f".into(),
            arguments: serde_json::Value::Null,
        };
        let msgs = vec![
            Message::Ai(AiMessage {
                content: String::new(),
                tool_calls: vec![call("t1")],
                parts: Vec::new(),
            }),
            Message::Ai(AiMessage {
                content: "done".into(),
                tool_calls: vec![call("t2")],
                parts: Vec::new(),
            }),
        ];
        let out = merge_message_runs(&msgs);
        assert_eq!(out.len(), 1);
        // No blank-line separator when one side's text is empty.
        assert_eq!(out[0].content(), "done");
        let ids: Vec<&str> = out[0].tool_calls().iter().map(|c| c.id.as_str()).collect();
        assert_eq!(ids, ["t1", "t2"]);
    }
}