use serde::{Deserialize, Serialize};
/// A conversation thread record.
///
/// NOTE(review): the expected value of `object` and the units of `created_at`
/// are not established anywhere in this file — confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Thread {
    /// Thread identifier.
    pub id: String,
    /// Object-type tag string.
    pub object: String,
    /// Creation timestamp; presumably Unix seconds like `unix_now()` — TODO confirm.
    pub created_at: i64,
    /// Free-form JSON metadata attached to the thread (any JSON value).
    pub metadata: serde_json::Value,
}
/// Author of a thread message.
///
/// Serialized in lowercase: `"user"` / `"assistant"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Message written by the end user.
    User,
    /// Message produced by the assistant.
    Assistant,
}
/// Annotation attached to a text block, kept as untyped JSON.
///
/// The `ThreadMessage` constructors in this file always create an empty list of these.
pub type Annotation = serde_json::Value;
/// The text payload of a content block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextContent {
    /// The raw text.
    pub value: String,
    /// Annotations on the text (untyped JSON values).
    pub annotations: Vec<Annotation>,
}
/// One block of message content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContentBlock {
    /// Block kind, serialized under the JSON key `type` (serde strips the `r#` prefix).
    /// The constructors in this file only ever produce `"text"`.
    pub r#type: String,
    /// The block's text payload.
    pub text: TextContent,
}
/// A single message stored in a thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThreadMessage {
    /// Message identifier.
    pub id: String,
    /// Object-type tag; the constructors set it to `"thread.message"`.
    pub object: String,
    /// Creation time as whole seconds since the Unix epoch (set via `unix_now()`).
    pub created_at: i64,
    /// Identifier of the thread this message belongs to.
    pub thread_id: String,
    /// Who authored the message.
    pub role: MessageRole,
    /// Content blocks; the constructors create exactly one text block.
    pub content: Vec<ContentBlock>,
    /// Run that produced this message: `Some` for assistant messages built with
    /// `new_assistant`, `None` for `new_user`. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub run_id: Option<String>,
}
impl ThreadMessage {
pub fn new_user(id: String, thread_id: String, content: String) -> Self {
Self {
id,
object: "thread.message".to_string(),
created_at: unix_now(),
thread_id,
role: MessageRole::User,
content: vec![ContentBlock {
r#type: "text".to_string(),
text: TextContent {
value: content,
annotations: vec![],
},
}],
run_id: None,
}
}
pub fn new_assistant(id: String, thread_id: String, run_id: String, content: String) -> Self {
Self {
id,
object: "thread.message".to_string(),
created_at: unix_now(),
thread_id,
role: MessageRole::Assistant,
content: vec![ContentBlock {
r#type: "text".to_string(),
text: TextContent {
value: content,
annotations: vec![],
},
}],
run_id: Some(run_id),
}
}
pub fn text_content(&self) -> &str {
self.content
.first()
.map(|b| b.text.value.as_str())
.unwrap_or("")
}
}
/// Lifecycle status of a run.
///
/// Serialized in snake_case (e.g. `InProgress` ⇄ `"in_progress"`).
/// See `RunStatus::is_terminal` for which statuses are final.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RunStatus {
    Queued,
    InProgress,
    Completed,
    Cancelled,
    Failed,
    Expired,
}
impl RunStatus {
    /// Returns `true` for statuses treated as final.
    ///
    /// `Completed`, `Cancelled`, `Failed` and `Expired` are terminal;
    /// `Queued` and `InProgress` are not.
    pub fn is_terminal(&self) -> bool {
        // Deliberately exhaustive (no `_` arm): adding a variant to `RunStatus`
        // becomes a compile error here instead of silently counting as non-terminal.
        match self {
            RunStatus::Queued | RunStatus::InProgress => false,
            RunStatus::Completed
            | RunStatus::Cancelled
            | RunStatus::Failed
            | RunStatus::Expired => true,
        }
    }
}
/// Error details recorded on a run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunError {
    /// Error code string.
    pub code: String,
    /// Error message.
    pub message: String,
}
/// A run executed against a thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Run {
    /// Run identifier.
    pub id: String,
    /// Object-type tag string.
    pub object: String,
    /// Creation timestamp; presumably Unix seconds like `unix_now()` — TODO confirm.
    pub created_at: i64,
    /// Identifier of the thread the run belongs to.
    pub thread_id: String,
    /// Current lifecycle status.
    pub status: RunStatus,
    /// Model name used for the run.
    pub model: String,
    /// Most recent error, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_error: Option<RunError>,
}
/// Kind of a run step. Serialized in snake_case (`"message_creation"`, `"tool_calls"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RunStepType {
    MessageCreation,
    ToolCalls,
}
/// Lifecycle status of a run step.
///
/// Serialized in snake_case (e.g. `InProgress` ⇄ `"in_progress"`).
/// See `RunStepStatus::is_terminal` for which statuses are final.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RunStepStatus {
    InProgress,
    Completed,
    Failed,
    Cancelled,
}
impl RunStepStatus {
    /// Returns `true` for statuses treated as final.
    ///
    /// `Completed`, `Failed` and `Cancelled` are terminal; `InProgress` is not.
    pub fn is_terminal(&self) -> bool {
        // Deliberately exhaustive (no `_` arm): a new variant must be classified
        // here explicitly rather than defaulting to non-terminal.
        match self {
            RunStepStatus::InProgress => false,
            RunStepStatus::Completed | RunStepStatus::Failed | RunStepStatus::Cancelled => true,
        }
    }
}
/// Details for a `message_creation` run step: the message it created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageCreationStepDetails {
    /// Identifier of the created message.
    pub message_id: String,
}
/// One step within a run.
///
/// NOTE(review): the timestamps here are `u64`, while `Thread`, `ThreadMessage`,
/// `Run` and `unix_now()` use `i64`. Confirm whether the mismatch is intentional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunStep {
    /// Step identifier.
    pub id: String,
    /// Object-type tag; `new_message_creation` sets it to `"thread.run.step"`.
    pub object: String,
    /// Run this step belongs to.
    pub run_id: String,
    /// Thread the run belongs to.
    pub thread_id: String,
    /// Step kind. Serialized under the key `step_type`, not `type`
    /// (unlike `ContentBlock::r#type`).
    pub step_type: RunStepType,
    /// Current lifecycle status.
    pub status: RunStepStatus,
    /// Creation timestamp; `new_message_creation` derives it from `unix_now()`.
    pub created_at: u64,
    /// Completion timestamp; omitted from serialized output when `None`.
    /// Presumably set on completion — nothing in this file sets it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<u64>,
    /// Failure timestamp; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub failed_at: Option<u64>,
    /// Error description, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Step details, if known; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub step_details: Option<MessageCreationStepDetails>,
}
impl RunStep {
    /// Creates a `MessageCreation` step with id `step_id` for run `run_id` in
    /// thread `thread_id`.
    ///
    /// The step starts `InProgress` with `created_at` set to the current Unix time
    /// in seconds. `completed_at`, `failed_at`, `error` and `step_details` are all `None`.
    pub fn new_message_creation(step_id: String, run_id: String, thread_id: String) -> Self {
        Self {
            id: step_id,
            object: "thread.run.step".to_string(),
            run_id,
            thread_id,
            step_type: RunStepType::MessageCreation,
            status: RunStepStatus::InProgress,
            // `unix_now()` returns i64 but this field is u64. Clamp at 0 first so the
            // cast is lossless. A bare `as u64` would turn a negative value into a
            // huge far-future timestamp.
            created_at: unix_now().max(0) as u64,
            completed_at: None,
            failed_at: None,
            error: None,
            step_details: None,
        }
    }
}
/// Request body for creating a message: the author role and its plain-text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateMessageRequest {
    /// Author of the message.
    pub role: MessageRole,
    /// Message text.
    pub content: String,
}
/// Request body for creating a thread.
///
/// Both fields are optional and deserialize as `None` when absent. The derived
/// `Default` produces the same empty request, so callers can write
/// `CreateThreadRequest { metadata: Some(..), ..Default::default() }`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CreateThreadRequest {
    /// Initial messages for the thread, if provided.
    #[serde(default)]
    pub messages: Option<Vec<CreateMessageRequest>>,
    /// JSON metadata to attach to the thread, if provided.
    #[serde(default)]
    pub metadata: Option<serde_json::Value>,
}
/// Request body for creating a run.
///
/// Every field may be omitted: the `Option`s deserialize as `None` and `stream`
/// as `false`. The derived `Default` produces exactly those values, so
/// `CreateRunRequest::default()` matches deserializing `{}`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CreateRunRequest {
    /// Requested model name, if any.
    #[serde(default)]
    pub model: Option<String>,
    /// Instructions text, if any.
    #[serde(default)]
    pub instructions: Option<String>,
    /// Token limit, if any.
    #[serde(default)]
    pub max_tokens: Option<usize>,
    /// Whether streaming was requested; `false` when omitted.
    #[serde(default)]
    pub stream: bool,
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch. Saturates at
/// `i64::MAX` instead of wrapping if the seconds count does not fit in an `i64`.
pub(crate) fn unix_now() -> i64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        // Clamp before narrowing: a bare `as i64` on a u64 above `i64::MAX` would
        // wrap to a negative timestamp. After `min`, the cast is lossless.
        .map(|elapsed| elapsed.as_secs().min(i64::MAX as u64) as i64)
        .unwrap_or(0)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn run_status_terminal_set_is_correct() {
        assert!(RunStatus::Completed.is_terminal());
        assert!(RunStatus::Cancelled.is_terminal());
        assert!(RunStatus::Failed.is_terminal());
        assert!(RunStatus::Expired.is_terminal());
        assert!(!RunStatus::Queued.is_terminal());
        assert!(!RunStatus::InProgress.is_terminal());
    }

    #[test]
    fn run_step_status_terminal_set_is_correct() {
        assert!(RunStepStatus::Completed.is_terminal());
        assert!(RunStepStatus::Failed.is_terminal());
        assert!(RunStepStatus::Cancelled.is_terminal());
        assert!(!RunStepStatus::InProgress.is_terminal());
    }

    #[test]
    fn thread_message_new_user_sets_fields() {
        let msg = ThreadMessage::new_user("msg_1".into(), "thread_1".into(), "hello".into());
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.object, "thread.message");
        assert_eq!(msg.thread_id, "thread_1");
        assert_eq!(msg.text_content(), "hello");
        assert!(msg.run_id.is_none());
    }

    #[test]
    fn thread_message_new_assistant_sets_run_id() {
        let msg = ThreadMessage::new_assistant(
            "msg_2".into(),
            "thread_1".into(),
            "run_1".into(),
            "hi!".into(),
        );
        assert_eq!(msg.role, MessageRole::Assistant);
        assert_eq!(msg.run_id, Some("run_1".into()));
        assert_eq!(msg.text_content(), "hi!");
    }

    #[test]
    fn text_content_of_message_without_blocks_is_empty() {
        let mut msg = ThreadMessage::new_user("msg_3".into(), "thread_1".into(), "x".into());
        msg.content.clear();
        assert_eq!(msg.text_content(), "");
    }

    #[test]
    fn message_json_uses_type_key_and_omits_missing_run_id() {
        let msg = ThreadMessage::new_user("msg_1".into(), "thread_1".into(), "hello".into());
        let json = serde_json::to_value(&msg).expect("serialize");
        assert_eq!(json["content"][0]["type"], "text");
        assert_eq!(json["content"][0]["text"]["value"], "hello");
        assert!(json.get("run_id").is_none());
    }

    #[test]
    fn run_step_new_message_creation_starts_in_progress() {
        let step =
            RunStep::new_message_creation("step_1".into(), "run_1".into(), "thread_1".into());
        assert_eq!(step.object, "thread.run.step");
        assert_eq!(step.step_type, RunStepType::MessageCreation);
        assert_eq!(step.status, RunStepStatus::InProgress);
        assert!(step.completed_at.is_none());
        assert!(step.failed_at.is_none());
        assert!(step.error.is_none());
        assert!(step.step_details.is_none());
    }

    #[test]
    fn run_status_serde_roundtrip() {
        let s = serde_json::to_string(&RunStatus::InProgress).expect("serialize");
        let d: RunStatus = serde_json::from_str(&s).expect("deserialize");
        assert_eq!(d, RunStatus::InProgress);
    }

    #[test]
    fn run_step_status_serde_snake_case() {
        let json = serde_json::to_string(&RunStepStatus::InProgress).expect("serialize");
        assert_eq!(json, r#""in_progress""#);
    }

    #[test]
    fn message_role_serde_lowercase() {
        let json = serde_json::to_string(&MessageRole::User).expect("serialize");
        assert_eq!(json, r#""user""#);
        let back: MessageRole = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back, MessageRole::User);
    }

    #[test]
    fn create_run_request_accepts_empty_body() {
        let req: CreateRunRequest = serde_json::from_str("{}").expect("deserialize");
        assert!(req.model.is_none());
        assert!(req.instructions.is_none());
        assert!(req.max_tokens.is_none());
        assert!(!req.stream);
    }
}