use serde::{Deserialize, Serialize};
use crate::task::{ContextId, TaskId};
/// Opaque identifier of a [`Message`].
///
/// A newtype over the raw string. Serde handles it as a newtype struct, which
/// `serde_json` writes as the bare inner string.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct MessageId(pub String);
impl MessageId {
    /// Builds a `MessageId` from anything convertible into a `String`.
    #[must_use]
    pub fn new(s: impl Into<String>) -> Self {
        let owned: String = s.into();
        Self(owned)
    }
}
impl std::fmt::Display for MessageId {
    /// Writes the raw identifier string verbatim.
    ///
    /// Width and fill flags on the formatter are not applied.
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw: &str = self.0.as_str();
        f.write_str(raw)
    }
}
impl From<String> for MessageId {
fn from(s: String) -> Self {
Self(s)
}
}
impl From<&str> for MessageId {
    /// Copies the borrowed string into a new owned identifier.
    fn from(s: &str) -> Self {
        Self(String::from(s))
    }
}
impl AsRef<str> for MessageId {
    /// Borrows the underlying identifier string.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
/// Author role of a [`Message`].
///
/// Serialized with the proto-style `ROLE_*` names. The lowercase legacy names
/// (`"unspecified"`, `"user"`, `"agent"`) are also accepted when deserializing.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum MessageRole {
    /// Serialized as `ROLE_UNSPECIFIED`; also accepts `unspecified`.
    #[serde(rename = "ROLE_UNSPECIFIED", alias = "unspecified")]
    Unspecified,
    /// Serialized as `ROLE_USER`; also accepts `user`.
    #[serde(rename = "ROLE_USER", alias = "user")]
    User,
    /// Serialized as `ROLE_AGENT`; also accepts `agent`.
    #[serde(rename = "ROLE_AGENT", alias = "agent")]
    Agent,
}
impl std::fmt::Display for MessageRole {
    /// Writes the same `ROLE_*` string that serde emits for this variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::User => "ROLE_USER",
            Self::Agent => "ROLE_AGENT",
            Self::Unspecified => "ROLE_UNSPECIFIED",
        })
    }
}
/// A message made of one or more content [`Part`]s, authored by a [`MessageRole`].
///
/// Field names are camelCased on the wire. Every `Option` field is omitted from
/// the serialized output when it is `None`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Message {
    /// Identifier of this message, serialized under the key `messageId`.
    #[serde(rename = "messageId")]
    pub id: MessageId,
    /// Who authored the message.
    pub role: MessageRole,
    /// Content parts, kept in the given order.
    pub parts: Vec<Part>,
    /// Associated task, if any (`taskId`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_id: Option<TaskId>,
    /// Associated context, if any (`contextId`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_id: Option<ContextId>,
    /// Other tasks this message refers to (`referenceTaskIds`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reference_task_ids: Option<Vec<TaskId>>,
    /// Extension identifiers attached to the message; the format is not checked here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extensions: Option<Vec<String>>,
    /// Arbitrary JSON metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
/// One unit of message content plus optional descriptive attributes.
///
/// Serializes in the flat v1.0 shape. `content` is flattened, so its variant
/// becomes a single top-level key (`text`, `raw`, `url` or `data`) next to the
/// optional `metadata`, `filename` and `mediaType` keys. Unset optionals are
/// omitted.
///
/// Only `Serialize` is derived. Deserialization is implemented by hand for this
/// type, and that impl (not a serde attribute) accepts both `mediaType` and
/// `media_type`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Part {
    /// The content payload, flattened into the enclosing object.
    #[serde(flatten)]
    pub content: PartContent,
    /// Arbitrary JSON metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
    /// File name associated with the content, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filename: Option<String>,
    /// MIME type of the content (serialized as `mediaType`).
    // No `alias` here: `Part` does not derive `Deserialize`, so serde would never
    // read it. The manual `Deserialize` impl handles the `media_type` spelling.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<String>,
}
// The key/visitor helper types are nested inside `deserialize` to keep them
// private to this impl, which pushes the function past clippy's line limit.
#[allow(clippy::too_many_lines)]
impl<'de> serde::Deserialize<'de> for Part {
    /// Deserializes a `Part` from its flat v1.0 object form.
    ///
    /// Recognised keys are `text`, `raw`, `url` and `data` (content), plus the
    /// optional `metadata`, `filename` and `mediaType` (`media_type` is also
    /// accepted). Unknown keys are skipped.
    ///
    /// An explicit `null` for any key except `data` is treated as if the key were
    /// absent. This matches how `#[derive(Deserialize)]` treats `Option` fields,
    /// for example on `Message`. `data` holds arbitrary JSON, so `"data": null`
    /// yields `PartContent::Data(Value::Null)`.
    ///
    /// If several content keys carry values, the first in the order `text`,
    /// `raw`, `url`, `data` wins and the others are discarded.
    ///
    /// # Errors
    ///
    /// Fails if a recognised key appears more than once, if a value has the
    /// wrong type, or if no content key carries a value.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};

        /// Keys recognised inside a `Part` object.
        #[derive(Debug)]
        enum Field {
            Text,
            Raw,
            Url,
            Data,
            Metadata,
            Filename,
            MediaType,
            Unknown,
        }

        impl<'de> serde::Deserialize<'de> for Field {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                struct FieldVisitor;
                impl serde::de::Visitor<'_> for FieldVisitor {
                    type Value = Field;
                    fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                        f.write_str("a Part field name")
                    }
                    fn visit_str<E: de::Error>(self, v: &str) -> Result<Field, E> {
                        Ok(match v {
                            "text" => Field::Text,
                            "raw" => Field::Raw,
                            "url" => Field::Url,
                            "data" => Field::Data,
                            "metadata" => Field::Metadata,
                            "filename" => Field::Filename,
                            "mediaType" | "media_type" => Field::MediaType,
                            _ => Field::Unknown,
                        })
                    }
                }
                deserializer.deserialize_identifier(FieldVisitor)
            }
        }

        /// Stores the next map value in `slot`, failing if `slot` was already filled.
        ///
        /// Nullable keys use `T = Option<_>`. A `null` value then still marks the
        /// key as seen (outer `Some`), so a later repeat of the same key is reported
        /// as a duplicate, as in serde's derived deserializers.
        fn read_once<'de, A, T>(
            map: &mut A,
            slot: &mut Option<T>,
            name: &'static str,
        ) -> Result<(), A::Error>
        where
            A: MapAccess<'de>,
            T: serde::Deserialize<'de>,
        {
            if slot.is_some() {
                return Err(de::Error::duplicate_field(name));
            }
            *slot = Some(map.next_value()?);
            Ok(())
        }

        struct PartVisitor;

        impl<'de> Visitor<'de> for PartVisitor {
            type Value = Part;

            fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("a Part object with text, raw, url, or data content")
            }

            fn visit_map<A>(self, mut map: A) -> Result<Part, A::Error>
            where
                A: MapAccess<'de>,
            {
                // Outer `Option`: was the key seen? Inner `Option`: was it non-null?
                let mut text: Option<Option<String>> = None;
                let mut raw: Option<Option<String>> = None;
                let mut url: Option<Option<String>> = None;
                // `data` is arbitrary JSON, so a `null` value is kept as `Value::Null`.
                let mut data: Option<serde_json::Value> = None;
                let mut metadata: Option<Option<serde_json::Value>> = None;
                let mut filename: Option<Option<String>> = None;
                let mut media_type: Option<Option<String>> = None;

                while let Some(key) = map.next_key::<Field>()? {
                    match key {
                        Field::Text => read_once(&mut map, &mut text, "text")?,
                        Field::Raw => read_once(&mut map, &mut raw, "raw")?,
                        Field::Url => read_once(&mut map, &mut url, "url")?,
                        Field::Data => read_once(&mut map, &mut data, "data")?,
                        Field::Metadata => read_once(&mut map, &mut metadata, "metadata")?,
                        Field::Filename => read_once(&mut map, &mut filename, "filename")?,
                        Field::MediaType => {
                            read_once(&mut map, &mut media_type, "mediaType")?;
                        }
                        Field::Unknown => {
                            let _: de::IgnoredAny = map.next_value()?;
                        }
                    }
                }

                // Precedence when several content keys are present: text > raw > url > data.
                let content = if let Some(t) = text.flatten() {
                    PartContent::Text(t)
                } else if let Some(r) = raw.flatten() {
                    PartContent::Raw(r)
                } else if let Some(u) = url.flatten() {
                    PartContent::Url(u)
                } else if let Some(d) = data {
                    PartContent::Data(d)
                } else {
                    return Err(de::Error::custom(
                        "Part must contain one of: text, raw, url, data",
                    ));
                };

                Ok(Part {
                    content,
                    metadata: metadata.flatten(),
                    filename: filename.flatten(),
                    media_type: media_type.flatten(),
                })
            }
        }

        deserializer.deserialize_map(PartVisitor)
    }
}

#[cfg(test)]
mod part_deserialize_null_tests {
    //! Regression tests for explicit `null` values inside a `Part` object.
    use super::*;

    #[test]
    fn null_optional_fields_are_treated_as_absent() {
        let part: Part = serde_json::from_str(
            r#"{"text":"hi","filename":null,"mediaType":null,"metadata":null}"#,
        )
        .expect("null optional fields should deserialize");
        assert_eq!(part, Part::text("hi"));
    }

    #[test]
    fn null_content_key_falls_through_to_next_variant() {
        let part: Part = serde_json::from_str(r#"{"text":null,"url":"https://example.com"}"#)
            .expect("deserialize");
        assert!(matches!(part.content, PartContent::Url(ref u) if u == "https://example.com"));
    }

    #[test]
    fn null_data_is_kept_as_value() {
        let part: Part = serde_json::from_str(r#"{"data":null}"#).expect("deserialize");
        assert_eq!(part.content, PartContent::Data(serde_json::Value::Null));
    }

    #[test]
    fn null_then_value_is_a_duplicate() {
        assert!(serde_json::from_str::<Part>(r#"{"text":null,"text":"x"}"#).is_err());
    }
}
impl Part {
#[must_use]
pub fn text(text: impl Into<String>) -> Self {
Self {
content: PartContent::Text(text.into()),
metadata: None,
filename: None,
media_type: None,
}
}
#[must_use]
pub fn raw(raw: impl Into<String>) -> Self {
Self {
content: PartContent::Raw(raw.into()),
metadata: None,
filename: None,
media_type: None,
}
}
#[must_use]
pub fn url(url: impl Into<String>) -> Self {
Self {
content: PartContent::Url(url.into()),
metadata: None,
filename: None,
media_type: None,
}
}
#[must_use]
pub const fn data(data: serde_json::Value) -> Self {
Self {
content: PartContent::Data(data),
metadata: None,
filename: None,
media_type: None,
}
}
#[must_use]
pub fn with_filename(mut self, filename: impl Into<String>) -> Self {
self.filename = Some(filename.into());
self
}
#[must_use]
pub fn with_media_type(mut self, media_type: impl Into<String>) -> Self {
self.media_type = Some(media_type.into());
self
}
#[must_use]
pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
self.metadata = Some(metadata);
self
}
#[must_use]
pub fn text_content(&self) -> Option<&str> {
match &self.content {
PartContent::Text(text) => Some(text),
_ => None,
}
}
#[must_use]
pub fn file_bytes(bytes: impl Into<String>) -> Self {
Self::raw(bytes)
}
#[must_use]
pub fn file_uri(uri: impl Into<String>) -> Self {
Self::url(uri)
}
#[must_use]
pub fn file(file: FileContent) -> Self {
let mut part = if let Some(bytes) = file.bytes {
Self::raw(bytes)
} else if let Some(uri) = file.uri {
Self::url(uri)
} else {
Self::raw("")
};
part.filename = file.name;
part.media_type = file.mime_type;
part
}
}
/// The payload of a [`Part`]: exactly one kind of content.
///
/// When flattened into `Part` for serialization, the camelCased variant name
/// becomes the JSON key, e.g. `{"text": "..."}`.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PartContent {
    /// Plain text.
    Text(String),
    /// Inline content as a string; the encoding is not enforced here.
    Raw(String),
    /// Reference to content by URL.
    Url(String),
    /// Arbitrary structured JSON.
    Data(serde_json::Value),
}
/// Older file-content shape, kept for backward compatibility.
///
/// [`Part::file`] converts it into a flat [`Part`]. At least one of `bytes` or
/// `uri` is expected to be set; see [`FileContent::validate`]. Field names are
/// camelCased on the wire, and `None` fields are omitted.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FileContent {
    /// File name (`name`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// MIME type (`mimeType`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mime_type: Option<String>,
    /// Inline file content as a string (`bytes`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bytes: Option<String>,
    /// Location of the file (`uri`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uri: Option<String>,
}
impl FileContent {
    /// Builds a `FileContent` that carries only inline `bytes`.
    #[must_use]
    pub fn from_bytes(bytes: impl Into<String>) -> Self {
        let bytes = Some(bytes.into());
        Self {
            bytes,
            uri: None,
            name: None,
            mime_type: None,
        }
    }

    /// Builds a `FileContent` that carries only a `uri`.
    #[must_use]
    pub fn from_uri(uri: impl Into<String>) -> Self {
        let uri = Some(uri.into());
        Self {
            uri,
            bytes: None,
            name: None,
            mime_type: None,
        }
    }

    /// Sets `name`, replacing any previous value.
    #[must_use]
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Sets `mime_type`, replacing any previous value.
    #[must_use]
    pub fn with_mime_type(self, mime_type: impl Into<String>) -> Self {
        Self {
            mime_type: Some(mime_type.into()),
            ..self
        }
    }

    /// Checks that the file has some content source.
    ///
    /// # Errors
    ///
    /// Returns a static message when both `bytes` and `uri` are `None`.
    pub const fn validate(&self) -> Result<(), &'static str> {
        if self.bytes.is_some() || self.uri.is_some() {
            Ok(())
        } else {
            Err("FileContent must have at least one of 'bytes' or 'uri' set")
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for message/part serialization and the convenience constructors.
    use super::*;

    /// Minimal user message with a single text part and no optional fields.
    fn make_message() -> Message {
        Message {
            id: MessageId::new("msg-1"),
            role: MessageRole::User,
            parts: vec![Part::text("Hello")],
            task_id: None,
            context_id: None,
            reference_task_ids: None,
            extensions: None,
            metadata: None,
        }
    }

    #[test]
    fn message_roundtrip() {
        let msg = make_message();
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(json.contains("\"messageId\":\"msg-1\""));
        assert!(json.contains("\"role\":\"ROLE_USER\""));
        let back: Message = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.id, MessageId::new("msg-1"));
        assert_eq!(back.role, MessageRole::User);
    }

    #[test]
    fn role_serializes_as_proto_names() {
        assert_eq!(
            serde_json::to_string(&MessageRole::User).unwrap(),
            "\"ROLE_USER\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::Agent).unwrap(),
            "\"ROLE_AGENT\""
        );
        assert_eq!(
            serde_json::to_string(&MessageRole::Unspecified).unwrap(),
            "\"ROLE_UNSPECIFIED\""
        );
    }

    #[test]
    fn role_accepts_legacy_lowercase() {
        let back: MessageRole = serde_json::from_str("\"user\"").unwrap();
        assert_eq!(back, MessageRole::User);
        let back: MessageRole = serde_json::from_str("\"agent\"").unwrap();
        assert_eq!(back, MessageRole::Agent);
    }

    #[test]
    fn role_display_matches_serde() {
        // Display and the serde rename must never drift apart.
        for role in [MessageRole::Unspecified, MessageRole::User, MessageRole::Agent] {
            let json = serde_json::to_string(&role).expect("serialize");
            assert_eq!(json, format!("\"{role}\""));
        }
    }

    #[test]
    fn text_part_v1_format() {
        let part = Part::text("hello world");
        let json = serde_json::to_string(&part).expect("serialize");
        assert!(
            json.contains("\"text\":\"hello world\""),
            "should have text field: {json}"
        );
        assert!(
            !json.contains("\"type\""),
            "v1.0 should not have type field: {json}"
        );
        let back: Part = serde_json::from_str(&json).expect("deserialize");
        assert!(matches!(back.content, PartContent::Text(ref t) if t == "hello world"));
    }

    #[test]
    fn raw_part_v1_format() {
        let part = Part::raw("aGVsbG8=")
            .with_filename("test.png")
            .with_media_type("image/png");
        let json = serde_json::to_string(&part).expect("serialize");
        assert!(json.contains("\"raw\":\"aGVsbG8=\""));
        assert!(json.contains("\"filename\":\"test.png\""));
        assert!(json.contains("\"mediaType\":\"image/png\""));
        assert!(!json.contains("\"type\""));
        let back: Part = serde_json::from_str(&json).expect("deserialize");
        assert!(matches!(back.content, PartContent::Raw(ref r) if r == "aGVsbG8="));
        assert_eq!(back.filename.as_deref(), Some("test.png"));
        assert_eq!(back.media_type.as_deref(), Some("image/png"));
    }

    #[test]
    fn url_part_v1_format() {
        let part = Part::url("https://example.com/file.pdf")
            .with_filename("file.pdf")
            .with_media_type("application/pdf");
        let json = serde_json::to_string(&part).expect("serialize");
        assert!(json.contains("\"url\":\"https://example.com/file.pdf\""));
        assert!(json.contains("\"filename\":\"file.pdf\""));
        assert!(!json.contains("\"type\""));
        let back: Part = serde_json::from_str(&json).expect("deserialize");
        assert!(
            matches!(back.content, PartContent::Url(ref u) if u == "https://example.com/file.pdf")
        );
    }

    #[test]
    fn data_part_v1_format() {
        let part = Part::data(serde_json::json!({"key": "value"}));
        let json = serde_json::to_string(&part).expect("serialize");
        assert!(json.contains("\"data\""));
        assert!(!json.contains("\"type\""));
        let back: Part = serde_json::from_str(&json).expect("deserialize");
        match &back.content {
            PartContent::Data(data) => assert_eq!(data["key"], "value"),
            _ => panic!("expected Data variant"),
        }
    }

    #[test]
    fn part_deserialize_requires_content() {
        let result = serde_json::from_str::<Part>(r#"{"filename":"a.txt"}"#);
        assert!(result.is_err(), "a part without content must be rejected");
    }

    #[test]
    fn part_deserialize_rejects_duplicate_keys() {
        assert!(serde_json::from_str::<Part>(r#"{"text":"a","text":"b"}"#).is_err());
    }

    #[test]
    fn part_deserialize_skips_unknown_keys() {
        let part: Part =
            serde_json::from_str(r#"{"kind":"text","text":"hi","extra":{"nested":[1,2]}}"#)
                .expect("unknown keys should be ignored");
        assert_eq!(part, Part::text("hi"));
    }

    #[test]
    fn part_deserialize_accepts_snake_case_media_type() {
        let part: Part = serde_json::from_str(r#"{"raw":"eA==","media_type":"text/plain"}"#)
            .expect("deserialize");
        assert_eq!(part.media_type.as_deref(), Some("text/plain"));
    }

    #[test]
    fn none_fields_omitted() {
        let msg = make_message();
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(
            !json.contains("\"taskId\""),
            "taskId should be omitted: {json}"
        );
        assert!(
            !json.contains("\"metadata\""),
            "metadata should be omitted: {json}"
        );
    }

    #[test]
    fn message_role_display_trait() {
        assert_eq!(MessageRole::User.to_string(), "ROLE_USER");
        assert_eq!(MessageRole::Agent.to_string(), "ROLE_AGENT");
        assert_eq!(MessageRole::Unspecified.to_string(), "ROLE_UNSPECIFIED");
    }

    #[test]
    fn message_with_reference_task_ids() {
        use crate::task::TaskId;
        let msg = Message {
            id: MessageId::new("msg-ref"),
            role: MessageRole::User,
            parts: vec![Part::text("check these tasks")],
            task_id: None,
            context_id: None,
            reference_task_ids: Some(vec![TaskId::new("task-100"), TaskId::new("task-200")]),
            extensions: None,
            metadata: None,
        };
        let json = serde_json::to_string(&msg).expect("serialize");
        assert!(json.contains("\"referenceTaskIds\""));
        assert!(json.contains("\"task-100\""));
        let back: Message = serde_json::from_str(&json).expect("deserialize");
        let refs = back
            .reference_task_ids
            .expect("should have reference_task_ids");
        assert_eq!(refs.len(), 2);
    }

    #[test]
    fn backward_compat_file_bytes_constructor() {
        let part = Part::file_bytes("aGVsbG8=");
        assert!(matches!(part.content, PartContent::Raw(_)));
    }

    #[test]
    fn backward_compat_file_uri_constructor() {
        let part = Part::file_uri("https://example.com/file.pdf");
        assert!(matches!(part.content, PartContent::Url(_)));
    }

    #[test]
    fn backward_compat_file_constructor() {
        let fc = FileContent::from_bytes("aGVsbG8=")
            .with_name("test.png")
            .with_mime_type("image/png");
        let part = Part::file(fc);
        assert!(matches!(part.content, PartContent::Raw(ref r) if r == "aGVsbG8="));
        assert_eq!(part.filename.as_deref(), Some("test.png"));
        assert_eq!(part.media_type.as_deref(), Some("image/png"));
    }

    #[test]
    fn backward_compat_file_constructor_uri_and_empty() {
        let part = Part::file(FileContent::from_uri("https://example.com/a.pdf"));
        assert!(
            matches!(part.content, PartContent::Url(ref u) if u == "https://example.com/a.pdf")
        );
        // With neither bytes nor uri, `validate` fails and `Part::file` falls back to an empty raw part.
        let empty = FileContent {
            name: None,
            mime_type: None,
            bytes: None,
            uri: None,
        };
        assert!(empty.validate().is_err());
        let part = Part::file(empty);
        assert!(matches!(part.content, PartContent::Raw(ref r) if r.is_empty()));
    }

    #[test]
    fn file_content_validate_accepts_bytes_or_uri() {
        assert!(FileContent::from_bytes("eA==").validate().is_ok());
        assert!(FileContent::from_uri("https://example.com").validate().is_ok());
    }

    #[test]
    fn text_content_only_for_text_parts() {
        assert_eq!(Part::text("hi").text_content(), Some("hi"));
        assert_eq!(Part::url("https://example.com").text_content(), None);
    }

    #[test]
    fn message_id_display() {
        let id = MessageId::new("msg-42");
        assert_eq!(id.to_string(), "msg-42");
    }

    #[test]
    fn message_id_as_ref() {
        let id = MessageId::new("ref-test");
        assert_eq!(id.as_ref(), "ref-test");
    }

    #[test]
    fn message_id_from_impls() {
        let from_str: MessageId = "str-id".into();
        assert_eq!(from_str, MessageId::new("str-id"));
        let from_string: MessageId = String::from("string-id").into();
        assert_eq!(from_string, MessageId::new("string-id"));
    }

    #[test]
    fn part_text_has_no_metadata() {
        let p = Part::text("hi");
        assert!(p.metadata.is_none());
        assert!(p.filename.is_none());
        assert!(p.media_type.is_none());
    }
}