use serde::{Deserialize, Serialize};
/// Contents of a resource identified by `uri`.
///
/// The constructors on this type set either `text` or `blob`, never both.
/// Deserialized values are not validated, however, and may carry neither or both.
///
/// The MIME type is serialized as `mimeType`, matching [`EmbeddedResource`] and
/// [`PromptContent::Image`]. The older `mime_type` spelling is still accepted
/// when deserializing, so existing payloads keep working.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceContent {
    /// URI identifying the resource.
    pub uri: String,
    /// MIME type of the content, if known; omitted from output when `None`.
    #[serde(
        rename = "mimeType",
        alias = "mime_type",
        skip_serializing_if = "Option::is_none"
    )]
    pub mime_type: Option<String>,
    /// Inline textual content; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Binary content encoded as a string (presumably base64 — confirm with
    /// producers/consumers); omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub blob: Option<String>,
}
impl ResourceContent {
    /// Private shared constructor for inline textual content with a fixed MIME type.
    fn textual(uri: String, mime_type: &str, text: String) -> Self {
        Self {
            uri,
            mime_type: Some(mime_type.to_owned()),
            text: Some(text),
            blob: None,
        }
    }

    /// Creates inline `text/plain` content for `uri`.
    pub fn text(uri: impl Into<String>, text: impl Into<String>) -> Self {
        Self::textual(uri.into(), "text/plain", text.into())
    }

    /// Creates inline `application/json` content for `uri`.
    ///
    /// `text` is stored verbatim; it is not parsed or validated as JSON.
    pub fn json(uri: impl Into<String>, text: impl Into<String>) -> Self {
        Self::textual(uri.into(), "application/json", text.into())
    }

    /// Creates binary content for `uri` with the caller-supplied MIME type.
    ///
    /// `blob` is stored as given; no encoding is applied or checked here.
    pub fn binary(
        uri: impl Into<String>,
        mime_type: impl Into<String>,
        blob: impl Into<String>,
    ) -> Self {
        let (uri, mime_type, blob) = (uri.into(), mime_type.into(), blob.into());
        Self {
            uri,
            mime_type: Some(mime_type),
            text: None,
            blob: Some(blob),
        }
    }

    /// Returns `true` when inline text is present.
    pub fn is_text(&self) -> bool {
        self.text.is_some()
    }

    /// Returns `true` when a binary `blob` is present.
    pub fn is_binary(&self) -> bool {
        self.blob.is_some()
    }

    /// Borrows the inline text, if any.
    pub fn as_text(&self) -> Option<&str> {
        self.text.as_deref()
    }
}
/// A prompt result: an optional description plus an ordered list of
/// [`PromptMessage`]s.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PromptResult {
    /// Human-readable description; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Messages, kept in the order they were supplied.
    pub messages: Vec<PromptMessage>,
}
impl PromptResult {
    /// Builds a result from `messages` with no description.
    pub fn new(messages: Vec<PromptMessage>) -> Self {
        Self {
            messages,
            description: None,
        }
    }

    /// Builds a result from `messages` carrying the given description.
    pub fn with_description(description: impl Into<String>, messages: Vec<PromptMessage>) -> Self {
        let description = Some(description.into());
        Self {
            description,
            ..Self::new(messages)
        }
    }

    /// Returns the first message, or `None` when there are no messages.
    pub fn first_message(&self) -> Option<&PromptMessage> {
        self.messages.first()
    }

    /// Iterates, in order, over messages whose role is exactly `"user"`.
    pub fn user_messages(&self) -> impl Iterator<Item = &PromptMessage> {
        self.messages_with_role("user")
    }

    /// Iterates, in order, over messages whose role is exactly `"assistant"`.
    pub fn assistant_messages(&self) -> impl Iterator<Item = &PromptMessage> {
        self.messages_with_role("assistant")
    }

    /// Lazily filters `messages` by a case-sensitive comparison against `role`.
    fn messages_with_role(&self, role: &'static str) -> impl Iterator<Item = &PromptMessage> {
        self.messages
            .iter()
            .filter(move |message| message.role == role)
    }
}
/// One message in a prompt: a role tag plus its content.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PromptMessage {
    /// Speaker role as a free-form string. The constructors on this type use
    /// `"user"`, `"assistant"` and `"system"`. Deserialization accepts any string.
    pub role: String,
    /// The message payload.
    pub content: PromptContent,
}
impl PromptMessage {
    /// Builds a plain-text message for `role`; shared by the public role constructors.
    fn text_with_role(role: &str, text: String) -> Self {
        Self {
            role: role.to_owned(),
            content: PromptContent::Text { text },
        }
    }

    /// Creates a text message with role `"user"`.
    pub fn user(text: impl Into<String>) -> Self {
        Self::text_with_role("user", text.into())
    }

    /// Creates a text message with role `"assistant"`.
    pub fn assistant(text: impl Into<String>) -> Self {
        Self::text_with_role("assistant", text.into())
    }

    /// Creates a text message with role `"system"`.
    ///
    /// NOTE(review): if these messages are exchanged over MCP, its prompt
    /// `Role` type only defines `user` and `assistant`. Confirm that consumers
    /// accept `system`.
    pub fn system(text: impl Into<String>) -> Self {
        Self::text_with_role("system", text.into())
    }

    /// Returns the message text when the content is [`PromptContent::Text`],
    /// and `None` for image or resource content.
    pub fn as_text(&self) -> Option<&str> {
        // Every variant is listed explicitly, with no `_` arm. Adding a new
        // content variant then becomes a compile error here, so this method
        // cannot silently return `None` for it.
        match &self.content {
            PromptContent::Text { text } => Some(text.as_str()),
            PromptContent::Image { .. } | PromptContent::Resource { .. } => None,
        }
    }
}
/// Content carried by a [`PromptMessage`].
///
/// Serialized as an internally tagged object. The lowercased variant name goes
/// in a `"type"` field next to the variant's own fields, e.g.
/// `{"type":"text","text":"Hello"}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "type")]
pub enum PromptContent {
    /// Plain text.
    Text { text: String },
    /// Inline image data plus its MIME type.
    Image {
        /// Image payload as a string (presumably base64-encoded — confirm).
        data: String,
        /// MIME type of `data`; serialized as `mimeType`.
        #[serde(rename = "mimeType")]
        mime_type: String,
    },
    /// A resource embedded directly in the message.
    Resource { resource: EmbeddedResource },
}
/// A resource embedded inline, as carried by [`PromptContent::Resource`].
///
/// It has a `uri`, an optional `mimeType`, and optional `text` / `blob`
/// payloads. Any optional field that is `None` is left out when serializing.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbeddedResource {
    /// URI identifying the embedded resource.
    pub uri: String,
    /// MIME type of the payload; serialized as `mimeType`.
    #[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")]
    pub mime_type: Option<String>,
    /// Inline textual payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Binary payload encoded as a string (presumably base64 — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub blob: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resource_content_text() {
        let content = ResourceContent::text("file://test.txt", "Hello, world!");
        assert!(content.is_text());
        assert!(!content.is_binary());
        assert_eq!(content.as_text(), Some("Hello, world!"));
    }

    #[test]
    fn test_resource_content_json() {
        let content = ResourceContent::json("file://data.json", r#"{"key": "value"}"#);
        assert!(content.is_text());
        assert_eq!(content.mime_type, Some("application/json".to_string()));
    }

    /// The binary constructor sets `blob` and the MIME type and leaves `text` unset.
    #[test]
    fn test_resource_content_binary() {
        let content = ResourceContent::binary("file://image.png", "image/png", "iVBORw0KGgo=");
        assert!(content.is_binary());
        assert!(!content.is_text());
        assert_eq!(content.as_text(), None);
        assert_eq!(content.mime_type.as_deref(), Some("image/png"));
        assert_eq!(content.blob.as_deref(), Some("iVBORw0KGgo="));
    }

    #[test]
    fn test_prompt_message_user() {
        let msg = PromptMessage::user("Hello!");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.as_text(), Some("Hello!"));
    }

    #[test]
    fn test_prompt_message_other_roles() {
        let assistant = PromptMessage::assistant("Hi there");
        assert_eq!(assistant.role, "assistant");
        assert_eq!(assistant.as_text(), Some("Hi there"));

        let system = PromptMessage::system("Be brief.");
        assert_eq!(system.role, "system");
        assert_eq!(system.as_text(), Some("Be brief."));
    }

    /// `as_text` yields `None` for every non-text content variant.
    #[test]
    fn test_prompt_message_as_text_non_text_content() {
        let image = PromptMessage {
            role: "user".to_string(),
            content: PromptContent::Image {
                data: "AAAA".to_string(),
                mime_type: "image/png".to_string(),
            },
        };
        assert_eq!(image.as_text(), None);

        let resource = PromptMessage {
            role: "user".to_string(),
            content: PromptContent::Resource {
                resource: EmbeddedResource {
                    uri: "file://a.txt".to_string(),
                    mime_type: None,
                    text: Some("hi".to_string()),
                    blob: None,
                },
            },
        };
        assert_eq!(resource.as_text(), None);
    }

    #[test]
    fn test_prompt_result() {
        let result = PromptResult::with_description(
            "Test prompt",
            vec![
                PromptMessage::user("What is 2+2?"),
                PromptMessage::assistant("4"),
            ],
        );
        assert_eq!(result.description, Some("Test prompt".to_string()));
        assert_eq!(result.messages.len(), 2);
        assert_eq!(result.user_messages().count(), 1);
        assert_eq!(result.assistant_messages().count(), 1);
    }

    #[test]
    fn test_prompt_result_empty() {
        let result = PromptResult::new(Vec::new());
        assert!(result.description.is_none());
        assert!(result.first_message().is_none());
        assert_eq!(result.user_messages().count(), 0);
        assert_eq!(result.assistant_messages().count(), 0);
    }

    /// Role filters keep order and exclude messages with other roles, such as `system`.
    #[test]
    fn test_prompt_result_role_filters() {
        let result = PromptResult::new(vec![
            PromptMessage::system("be brief"),
            PromptMessage::user("hi"),
            PromptMessage::assistant("hello"),
            PromptMessage::user("bye"),
        ]);
        assert_eq!(
            result.first_message().and_then(PromptMessage::as_text),
            Some("be brief")
        );
        let users: Vec<&str> = result
            .user_messages()
            .filter_map(PromptMessage::as_text)
            .collect();
        assert_eq!(users, ["hi", "bye"]);
        assert_eq!(result.assistant_messages().count(), 1);
    }

    #[test]
    fn test_prompt_content_serialization() {
        let content = PromptContent::Text {
            text: "Hello".to_string(),
        };
        let json = serde_json::to_string(&content).unwrap();
        assert!(json.contains(r#""type":"text""#));
        assert!(json.contains(r#""text":"Hello""#));
    }

    /// Image content must use the camelCase `mimeType` key on the wire.
    #[test]
    fn test_image_content_serialization() {
        let content = PromptContent::Image {
            data: "AAAA".to_string(),
            mime_type: "image/png".to_string(),
        };
        let value = serde_json::to_value(&content).unwrap();
        assert_eq!(
            value,
            serde_json::json!({"type": "image", "data": "AAAA", "mimeType": "image/png"})
        );
    }

    /// Embedded resources deserialize from `mimeType`. Re-serializing drops the
    /// absent `blob` and reproduces the input.
    #[test]
    fn test_embedded_resource_round_trip() {
        let json = r#"{"type":"resource","resource":{"uri":"file://a.txt","mimeType":"text/plain","text":"hi"}}"#;
        let content: PromptContent = serde_json::from_str(json).unwrap();
        match &content {
            PromptContent::Resource { resource } => {
                assert_eq!(resource.uri, "file://a.txt");
                assert_eq!(resource.mime_type.as_deref(), Some("text/plain"));
                assert_eq!(resource.text.as_deref(), Some("hi"));
                assert!(resource.blob.is_none());
            }
            other => panic!("expected resource content, got {other:?}"),
        }
        let expected: serde_json::Value = serde_json::from_str(json).unwrap();
        assert_eq!(serde_json::to_value(&content).unwrap(), expected);
    }
}