1use serde::{Deserialize, Serialize};
42use std::collections::HashMap;
43use uuid::Uuid;
44
45#[cfg(feature = "specta")]
46use specta::Type;
47
/// Detail level attached to an image content part.
///
/// Serialized in lowercase (`"low"`, `"high"`, `"auto"`). [`ImageDetail::Auto`] is the
/// `Default`. NOTE(review): presumably forwarded to the model provider as a fidelity
/// hint — confirm against the provider adapters.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum ImageDetail {
    /// Serialized as `"low"`.
    Low,
    /// Serialized as `"high"`.
    High,
    /// Serialized as `"auto"`; this is the `Default` variant.
    #[default]
    Auto,
}
63
/// Where the data for an image content part comes from.
///
/// Internally tagged on `"type"` with snake_case tags, i.e.
/// `{"type": "url", "url": "..."}` or
/// `{"type": "base64", "media_type": "...", "data": "..."}`.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Image referenced by URL (not validated here).
    Url { url: String },
    /// Image embedded inline as base64 text.
    Base64 {
        /// MIME type of the payload (presumably e.g. `image/png`; not validated here).
        media_type: String,
        /// Base64-encoded image payload (encoding not validated here).
        data: String,
    },
}
79
/// One element of a multi-part message body: either a text segment or an image.
///
/// Internally tagged on `"type"` with snake_case tags (`"text"` / `"image"`).
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// Plain text segment.
    Text { text: String },
    /// Image segment.
    Image {
        /// Location (URL) or inline bytes of the image.
        source: ImageSource,
        /// Optional detail hint; omitted from the serialized JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        detail: Option<ImageDetail>,
    },
}
97
98impl From<&str> for ContentPart {
99 fn from(text: &str) -> Self {
100 ContentPart::Text {
101 text: text.to_string(),
102 }
103 }
104}
105
106impl From<String> for ContentPart {
107 fn from(text: String) -> Self {
108 ContentPart::Text { text }
109 }
110}
111
/// A single tool invocation requested by an AI message.
///
/// Fields are private: build values with [`ToolCall::new`] or [`ToolCall::with_id`] and
/// read them through the accessor methods.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Identifier of this call (a random UUID v4 string when built with `ToolCall::new`).
    id: String,
    /// Name of the tool to invoke.
    name: String,
    /// Tool arguments as arbitrary JSON.
    args: serde_json::Value,
}
123
124impl ToolCall {
125 pub fn new(name: impl Into<String>, args: serde_json::Value) -> Self {
127 Self {
128 id: Uuid::new_v4().to_string(),
129 name: name.into(),
130 args,
131 }
132 }
133
134 pub fn with_id(
136 id: impl Into<String>,
137 name: impl Into<String>,
138 args: serde_json::Value,
139 ) -> Self {
140 Self {
141 id: id.into(),
142 name: name.into(),
143 args,
144 }
145 }
146
147 pub fn id(&self) -> &str {
149 &self.id
150 }
151
152 pub fn name(&self) -> &str {
154 &self.name
155 }
156
157 pub fn args(&self) -> &serde_json::Value {
159 &self.args
160 }
161}
162
/// Body of a message: either a single plain string or an ordered list of content parts.
///
/// `#[serde(untagged)]`: `Text` serializes as a bare JSON string and `Parts` as a JSON
/// array. On deserialization serde tries the variants in declaration order (`Text` first,
/// then `Parts`), so keep that order.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text content.
    Text(String),
    /// Multi-part content (text and/or images).
    Parts(Vec<ContentPart>),
}
173
174impl MessageContent {
175 pub fn as_text(&self) -> String {
177 match self {
178 MessageContent::Text(s) => s.clone(),
179 MessageContent::Parts(parts) => parts
180 .iter()
181 .filter_map(|p| match p {
182 ContentPart::Text { text } => Some(text.as_str()),
183 _ => None,
184 })
185 .collect::<Vec<_>>()
186 .join(" "),
187 }
188 }
189
190 pub fn has_images(&self) -> bool {
192 match self {
193 MessageContent::Text(_) => false,
194 MessageContent::Parts(parts) => {
195 parts.iter().any(|p| matches!(p, ContentPart::Image { .. }))
196 }
197 }
198 }
199
200 pub fn parts(&self) -> Vec<ContentPart> {
202 match self {
203 MessageContent::Text(s) => vec![ContentPart::Text { text: s.clone() }],
204 MessageContent::Parts(parts) => parts.clone(),
205 }
206 }
207}
208
209impl From<String> for MessageContent {
210 fn from(s: String) -> Self {
211 MessageContent::Text(s)
212 }
213}
214
215impl From<&str> for MessageContent {
216 fn from(s: &str) -> Self {
217 MessageContent::Text(s.to_string())
218 }
219}
220
221impl From<Vec<ContentPart>> for MessageContent {
222 fn from(parts: Vec<ContentPart>) -> Self {
223 MessageContent::Parts(parts)
224 }
225}
226
/// A message from the human/user side of a conversation.
///
/// Content may be plain text or multi-part (text and images); see [`MessageContent`].
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct HumanMessage {
    /// Message body.
    content: MessageContent,
    /// Optional message identifier; every constructor on `HumanMessage` sets it to `Some`.
    id: Option<String>,
    /// Free-form extra metadata; defaults to an empty map when absent during deserialization.
    #[serde(default)]
    additional_kwargs: HashMap<String, serde_json::Value>,
}
243
244impl HumanMessage {
245 pub fn new(content: impl Into<String>) -> Self {
247 Self {
248 content: MessageContent::Text(content.into()),
249 id: Some(Uuid::new_v4().to_string()),
250 additional_kwargs: HashMap::new(),
251 }
252 }
253
254 pub fn with_id(id: impl Into<String>, content: impl Into<String>) -> Self {
258 Self {
259 content: MessageContent::Text(content.into()),
260 id: Some(id.into()),
261 additional_kwargs: HashMap::new(),
262 }
263 }
264
265 pub fn with_content(parts: Vec<ContentPart>) -> Self {
283 Self {
284 content: MessageContent::Parts(parts),
285 id: Some(Uuid::new_v4().to_string()),
286 additional_kwargs: HashMap::new(),
287 }
288 }
289
290 pub fn with_id_and_content(id: impl Into<String>, parts: Vec<ContentPart>) -> Self {
294 Self {
295 content: MessageContent::Parts(parts),
296 id: Some(id.into()),
297 additional_kwargs: HashMap::new(),
298 }
299 }
300
301 pub fn with_image_url(text: impl Into<String>, url: impl Into<String>) -> Self {
303 Self::with_content(vec![
304 ContentPart::Text { text: text.into() },
305 ContentPart::Image {
306 source: ImageSource::Url { url: url.into() },
307 detail: None,
308 },
309 ])
310 }
311
312 pub fn with_image_base64(
314 text: impl Into<String>,
315 media_type: impl Into<String>,
316 data: impl Into<String>,
317 ) -> Self {
318 Self::with_content(vec![
319 ContentPart::Text { text: text.into() },
320 ContentPart::Image {
321 source: ImageSource::Base64 {
322 media_type: media_type.into(),
323 data: data.into(),
324 },
325 detail: None,
326 },
327 ])
328 }
329
330 pub fn content(&self) -> &str {
334 match &self.content {
335 MessageContent::Text(s) => s,
336 MessageContent::Parts(_) => "",
337 }
338 }
339
340 pub fn message_content(&self) -> &MessageContent {
342 &self.content
343 }
344
345 pub fn has_images(&self) -> bool {
347 self.content.has_images()
348 }
349
350 pub fn id(&self) -> Option<&str> {
352 self.id.as_deref()
353 }
354}
355
/// A system-role message with plain-text content.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SystemMessage {
    /// Message text.
    content: String,
    /// Optional message identifier; every constructor on `SystemMessage` sets it to `Some`.
    id: Option<String>,
    /// Free-form extra metadata; defaults to an empty map when absent during deserialization.
    #[serde(default)]
    additional_kwargs: HashMap<String, serde_json::Value>,
}
368
369impl SystemMessage {
370 pub fn new(content: impl Into<String>) -> Self {
372 Self {
373 content: content.into(),
374 id: Some(Uuid::new_v4().to_string()),
375 additional_kwargs: HashMap::new(),
376 }
377 }
378
379 pub fn with_id(id: impl Into<String>, content: impl Into<String>) -> Self {
383 Self {
384 content: content.into(),
385 id: Some(id.into()),
386 additional_kwargs: HashMap::new(),
387 }
388 }
389
390 pub fn content(&self) -> &str {
392 &self.content
393 }
394
395 pub fn id(&self) -> Option<&str> {
397 self.id.as_deref()
398 }
399}
400
/// A message produced by the AI/model side of a conversation, optionally carrying tool calls.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AIMessage {
    /// Message text; may be empty (e.g. when the message only carries tool calls).
    content: String,
    /// Optional message identifier; every constructor on `AIMessage` sets it to `Some`.
    id: Option<String>,
    /// Tool invocations requested by this message; empty when absent during deserialization.
    #[serde(default)]
    tool_calls: Vec<ToolCall>,
    /// Free-form extra metadata (for example the `"annotations"` entry written by
    /// `AIMessage::with_annotations`); empty when absent during deserialization.
    #[serde(default)]
    additional_kwargs: HashMap<String, serde_json::Value>,
}
416
417impl AIMessage {
418 pub fn new(content: impl Into<String>) -> Self {
420 Self {
421 content: content.into(),
422 id: Some(Uuid::new_v4().to_string()),
423 tool_calls: Vec::new(),
424 additional_kwargs: HashMap::new(),
425 }
426 }
427
428 pub fn with_id(id: impl Into<String>, content: impl Into<String>) -> Self {
432 Self {
433 content: content.into(),
434 id: Some(id.into()),
435 tool_calls: Vec::new(),
436 additional_kwargs: HashMap::new(),
437 }
438 }
439
440 pub fn with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
442 Self {
443 content: content.into(),
444 id: Some(Uuid::new_v4().to_string()),
445 tool_calls,
446 additional_kwargs: HashMap::new(),
447 }
448 }
449
450 pub fn with_id_and_tool_calls(
454 id: impl Into<String>,
455 content: impl Into<String>,
456 tool_calls: Vec<ToolCall>,
457 ) -> Self {
458 Self {
459 content: content.into(),
460 id: Some(id.into()),
461 tool_calls,
462 additional_kwargs: HashMap::new(),
463 }
464 }
465
466 pub fn content(&self) -> &str {
468 &self.content
469 }
470
471 pub fn id(&self) -> Option<&str> {
473 self.id.as_deref()
474 }
475
476 pub fn tool_calls(&self) -> &[ToolCall] {
478 &self.tool_calls
479 }
480
481 pub fn with_annotations<T: Serialize>(mut self, annotations: Vec<T>) -> Self {
484 if let Ok(value) = serde_json::to_value(&annotations) {
485 self.additional_kwargs
486 .insert("annotations".to_string(), value);
487 }
488 self
489 }
490
491 pub fn annotations(&self) -> Option<&serde_json::Value> {
493 self.additional_kwargs.get("annotations")
494 }
495
496 pub fn additional_kwargs(&self) -> &HashMap<String, serde_json::Value> {
498 &self.additional_kwargs
499 }
500}
501
/// A message carrying the output of a tool invocation.
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolMessage {
    /// Tool output as text.
    content: String,
    /// Identifier of the call this message answers. NOTE(review): presumably matches a
    /// `ToolCall::id()` from a preceding AI message; this type does not enforce it.
    tool_call_id: String,
    /// Optional message identifier; every constructor on `ToolMessage` sets it to `Some`.
    id: Option<String>,
    /// Free-form extra metadata; defaults to an empty map when absent during deserialization.
    #[serde(default)]
    additional_kwargs: HashMap<String, serde_json::Value>,
}
516
517impl ToolMessage {
518 pub fn new(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
520 Self {
521 content: content.into(),
522 tool_call_id: tool_call_id.into(),
523 id: Some(Uuid::new_v4().to_string()),
524 additional_kwargs: HashMap::new(),
525 }
526 }
527
528 pub fn with_id(
532 id: impl Into<String>,
533 content: impl Into<String>,
534 tool_call_id: impl Into<String>,
535 ) -> Self {
536 Self {
537 content: content.into(),
538 tool_call_id: tool_call_id.into(),
539 id: Some(id.into()),
540 additional_kwargs: HashMap::new(),
541 }
542 }
543
544 pub fn content(&self) -> &str {
546 &self.content
547 }
548
549 pub fn tool_call_id(&self) -> &str {
551 &self.tool_call_id
552 }
553
554 pub fn id(&self) -> Option<&str> {
556 self.id.as_deref()
557 }
558}
559
/// Any chat message, discriminated by role.
///
/// Internally tagged on `"type"`: the wrapped message's fields appear alongside the tag in
/// one JSON object. The tag values are the variant names verbatim (`"Human"`, `"System"`,
/// `"AI"`, `"Tool"`) because no `rename_all` is applied. They therefore differ in case from
/// the strings returned by `BaseMessage::message_type` (`"human"`, `"ai"`, ...).
#[cfg_attr(feature = "specta", derive(Type))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum BaseMessage {
    /// User-authored message.
    Human(HumanMessage),
    /// System-role message.
    System(SystemMessage),
    /// Model-authored message, possibly carrying tool calls.
    AI(AIMessage),
    /// Tool-result message.
    Tool(ToolMessage),
}
574
575impl BaseMessage {
576 pub fn content(&self) -> &str {
578 match self {
579 BaseMessage::Human(m) => m.content(),
580 BaseMessage::System(m) => m.content(),
581 BaseMessage::AI(m) => m.content(),
582 BaseMessage::Tool(m) => m.content(),
583 }
584 }
585
586 pub fn id(&self) -> Option<&str> {
588 match self {
589 BaseMessage::Human(m) => m.id(),
590 BaseMessage::System(m) => m.id(),
591 BaseMessage::AI(m) => m.id(),
592 BaseMessage::Tool(m) => m.id(),
593 }
594 }
595
596 pub fn tool_calls(&self) -> &[ToolCall] {
598 match self {
599 BaseMessage::AI(m) => m.tool_calls(),
600 _ => &[],
601 }
602 }
603
604 pub fn message_type(&self) -> &'static str {
606 match self {
607 BaseMessage::Human(_) => "human",
608 BaseMessage::System(_) => "system",
609 BaseMessage::AI(_) => "ai",
610 BaseMessage::Tool(_) => "tool",
611 }
612 }
613}
614
615impl From<HumanMessage> for BaseMessage {
616 fn from(msg: HumanMessage) -> Self {
617 BaseMessage::Human(msg)
618 }
619}
620
621impl From<SystemMessage> for BaseMessage {
622 fn from(msg: SystemMessage) -> Self {
623 BaseMessage::System(msg)
624 }
625}
626
627impl From<AIMessage> for BaseMessage {
628 fn from(msg: AIMessage) -> Self {
629 BaseMessage::AI(msg)
630 }
631}
632
633impl From<ToolMessage> for BaseMessage {
634 fn from(msg: ToolMessage) -> Self {
635 BaseMessage::Tool(msg)
636 }
637}
638
639pub trait HasId {
642 fn get_id(&self) -> Option<&str>;
644}
645
646impl HasId for BaseMessage {
647 fn get_id(&self) -> Option<&str> {
648 self.id()
649 }
650}
651
652pub type AnyMessage = BaseMessage;
655
656impl BaseMessage {
657 pub fn pretty_print(&self) {
660 let (role, content) = match self {
661 BaseMessage::Human(m) => ("Human", m.content()),
662 BaseMessage::System(m) => ("System", m.content()),
663 BaseMessage::AI(m) => {
664 let tool_calls = m.tool_calls();
665 if tool_calls.is_empty() {
666 ("AI", m.content())
667 } else {
668 println!(
669 "================================== AI Message =================================="
670 );
671 if !m.content().is_empty() {
672 println!("{}", m.content());
673 }
674 for tc in tool_calls {
675 println!("Tool Call: {} ({})", tc.name(), tc.id());
676 println!(" Args: {}", tc.args());
677 }
678 return;
679 }
680 }
681 BaseMessage::Tool(m) => {
682 println!(
683 "================================= Tool Message ================================="
684 );
685 println!("[{}] {}", m.tool_call_id(), m.content());
686 return;
687 }
688 };
689
690 let header = format!("=== {} Message ===", role);
691 let padding = (80 - header.len()) / 2;
692 println!(
693 "{:=>padding$}{}{:=>padding$}",
694 "",
695 header,
696 "",
697 padding = padding
698 );
699 println!("{}", content);
700 }
701}