use serde::{Deserialize, Serialize};
use uuid::Uuid;
pub type NodeId = Uuid;
pub type NodeSessionId = Uuid;
pub type NodeTaskId = Uuid;
pub type NodeLeaseId = Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeModelCapability {
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCapabilities {
pub runtime: String,
pub models: Vec<NodeModelCapability>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeRegisterRequest {
pub protocol_version: String,
pub client_instance_id: String,
pub display_name: String,
pub registration_token: String,
pub capabilities: NodeCapabilities,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeRegisterResponse {
pub protocol_version: String,
pub node_id: NodeId,
pub session_id: NodeSessionId,
pub session_token: String,
pub heartbeat_interval_secs: u64,
pub poll_timeout_secs: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeHeartbeatRequest {
pub protocol_version: String,
pub node_id: NodeId,
pub session_id: NodeSessionId,
pub accepted_models: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeHeartbeatResponse {
pub protocol_version: String,
pub accepted: bool,
pub node_status: String,
pub server_failure_count: u32,
pub failure_threshold: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodePollRequest {
pub protocol_version: String,
pub node_id: NodeId,
pub session_id: NodeSessionId,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodePollResponse {
pub protocol_version: String,
pub task: Option<NodeTaskEnvelope>,
pub retry_after_ms: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeTaskEnvelope {
pub task_id: NodeTaskId,
pub lease_id: NodeLeaseId,
pub model: String,
pub deadline_unix_ms: i64,
pub complete_grace_until_unix_ms: i64,
pub payload: NodeTaskPayload,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeTaskPayload {
pub request_id: Uuid,
pub chat: ChatCompletionRequest,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeTaskCompleteRequest {
pub protocol_version: String,
pub node_id: NodeId,
pub session_id: NodeSessionId,
pub task_id: NodeTaskId,
pub lease_id: NodeLeaseId,
pub result: NodeTaskResult,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "status")]
pub enum NodeTaskResult {
Succeeded {
response: ChatCompletionResponse,
},
Failed {
code: String,
message: String,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeTaskCompleteResponse {
pub action: NodeTaskCompleteAction,
pub task_status: String,
pub node_status: String,
pub server_failure_count: u32,
pub failure_threshold: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum NodeTaskCompleteAction {
Succeeded,
Requeued,
Failed,
Expired,
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
System,
#[default]
User,
Assistant,
Tool,
}
impl MessageRole {
pub fn as_str(&self) -> &'static str {
match self {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
MessageRole::Tool => "tool",
}
}
}
impl std::fmt::Display for MessageRole {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_str())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: MessageRole,
pub content: String,
}
impl Message {
#[allow(dead_code)] pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
Self {
role,
content: content.into(),
}
}
#[allow(dead_code)] pub fn system(content: impl Into<String>) -> Self {
Self::new(MessageRole::System, content)
}
#[allow(dead_code)] pub fn user(content: impl Into<String>) -> Self {
Self::new(MessageRole::User, content)
}
#[allow(dead_code)] pub fn assistant(content: impl Into<String>) -> Self {
Self::new(MessageRole::Assistant, content)
}
#[allow(dead_code)] pub fn tool(content: impl Into<String>) -> Self {
Self::new(MessageRole::Tool, content)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
}
impl ChatCompletionRequest {
#[allow(dead_code)] pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
Self {
model: model.into(),
messages,
stream: None,
max_tokens: None,
temperature: None,
top_p: None,
n: None,
stop: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: i64,
pub model: String,
pub choices: Vec<CompletionChoice>,
pub usage: Usage,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionChoice {
pub index: u32,
pub message: MessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageContent {
pub role: String,
pub content: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_role_serialize() {
let role = MessageRole::Assistant;
let json = serde_json::to_string(&role).unwrap();
assert_eq!(json, "\"assistant\"");
}
#[test]
fn test_message_role_deserialize() {
let json = "\"system\"";
let role: MessageRole = serde_json::from_str(json).unwrap();
assert_eq!(role, MessageRole::System);
}
#[test]
fn test_chat_completion_request_serialize() {
let req = ChatCompletionRequest::new("deepseek-chat", vec![Message::user("Hello")]);
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"model\":\"deepseek-chat\""));
assert!(json.contains("\"role\":\"user\""));
}
#[test]
fn test_node_model_capability_serialize() {
let capability = NodeModelCapability {
model: "deepseek-chat".to_string(),
};
let json = serde_json::to_string(&capability).unwrap();
let parsed: NodeModelCapability = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.model, "deepseek-chat");
}
#[test]
fn test_node_model_capability_deserialize() {
let json = r#"{"model":"llama3"}"#;
let capability: NodeModelCapability = serde_json::from_str(json).unwrap();
assert_eq!(capability.model, "llama3");
}
#[test]
fn test_node_capabilities_serialize() {
let caps = NodeCapabilities {
runtime: "ollama".to_string(),
models: vec![
NodeModelCapability {
model: "deepseek-chat".to_string(),
},
NodeModelCapability {
model: "llama3".to_string(),
},
],
};
let json = serde_json::to_string(&caps).unwrap();
let parsed: NodeCapabilities = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.runtime, "ollama");
assert_eq!(parsed.models.len(), 2);
assert_eq!(parsed.models[0].model, "deepseek-chat");
assert_eq!(parsed.models[1].model, "llama3");
}
#[test]
fn test_node_capabilities_deserialize() {
let json = r#"{
"runtime": "ollama",
"models": [
{"model": "deepseek-chat"},
{"model": "llama3"}
]
}"#;
let caps: NodeCapabilities = serde_json::from_str(json).unwrap();
assert_eq!(caps.runtime, "ollama");
assert_eq!(caps.models.len(), 2);
assert_eq!(caps.models[0].model, "deepseek-chat");
assert_eq!(caps.models[1].model, "llama3");
}
#[test]
fn test_node_register_request_serialize() {
let req = NodeRegisterRequest {
protocol_version: "node.v1".to_string(),
client_instance_id: "test-instance-001".to_string(),
display_name: "Test Node".to_string(),
registration_token: "secret-token".to_string(),
capabilities: NodeCapabilities {
runtime: "ollama".to_string(),
models: vec![NodeModelCapability {
model: "deepseek-chat".to_string(),
}],
},
};
let json = serde_json::to_string(&req).unwrap();
let parsed: NodeRegisterRequest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.protocol_version, "node.v1");
assert_eq!(parsed.client_instance_id, "test-instance-001");
assert_eq!(parsed.display_name, "Test Node");
assert_eq!(parsed.capabilities.runtime, "ollama");
assert_eq!(parsed.capabilities.models.len(), 1);
}
#[test]
fn test_node_register_response_deserialize() {
let node_id = Uuid::new_v4();
let session_id = Uuid::new_v4();
let json = serde_json::json!({
"protocol_version": "node.v1",
"node_id": node_id.to_string(),
"session_id": session_id.to_string(),
"session_token": "test-session-token",
"heartbeat_interval_secs": 30,
"poll_timeout_secs": 60
});
let resp: NodeRegisterResponse = serde_json::from_value(json).unwrap();
assert_eq!(resp.protocol_version, "node.v1");
assert_eq!(resp.node_id, node_id);
assert_eq!(resp.session_id, session_id);
assert_eq!(resp.session_token, "test-session-token");
assert_eq!(resp.heartbeat_interval_secs, 30);
assert_eq!(resp.poll_timeout_secs, 60);
}
#[test]
fn test_node_heartbeat_request_serialize() {
let node_id = Uuid::new_v4();
let session_id = Uuid::new_v4();
let req = NodeHeartbeatRequest {
protocol_version: "node.v1".to_string(),
node_id,
session_id,
accepted_models: vec!["deepseek-chat".to_string(), "llama3".to_string()],
};
let json = serde_json::to_string(&req).unwrap();
let parsed: NodeHeartbeatRequest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.protocol_version, "node.v1");
assert_eq!(parsed.node_id, node_id);
assert_eq!(parsed.session_id, session_id);
assert_eq!(parsed.accepted_models.len(), 2);
assert!(parsed
.accepted_models
.contains(&"deepseek-chat".to_string()));
assert!(parsed.accepted_models.contains(&"llama3".to_string()));
}
#[test]
fn test_node_heartbeat_response_deserialize() {
let json = serde_json::json!({
"protocol_version": "node.v1",
"accepted": true,
"node_status": "online",
"server_failure_count": 0,
"failure_threshold": 3
});
let resp: NodeHeartbeatResponse = serde_json::from_value(json).unwrap();
assert_eq!(resp.protocol_version, "node.v1");
assert!(resp.accepted);
assert_eq!(resp.node_status, "online");
assert_eq!(resp.server_failure_count, 0);
assert_eq!(resp.failure_threshold, 3);
}
#[test]
fn test_node_heartbeat_response_excluded() {
let json = serde_json::json!({
"protocol_version": "node.v1",
"accepted": true,
"node_status": "excluded",
"server_failure_count": 3,
"failure_threshold": 3
});
let resp: NodeHeartbeatResponse = serde_json::from_value(json).unwrap();
assert!(resp.accepted);
assert_eq!(resp.node_status, "excluded");
assert_eq!(resp.server_failure_count, 3);
}
#[test]
fn test_node_poll_request_serialize() {
let node_id = Uuid::new_v4();
let session_id = Uuid::new_v4();
let req = NodePollRequest {
protocol_version: "node.v1".to_string(),
node_id,
session_id,
};
let json = serde_json::to_string(&req).unwrap();
let parsed: NodePollRequest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.protocol_version, "node.v1");
assert_eq!(parsed.node_id, node_id);
assert_eq!(parsed.session_id, session_id);
}
#[test]
fn test_node_poll_response_with_task() {
let task_id = Uuid::new_v4();
let lease_id = Uuid::new_v4();
let json = serde_json::json!({
"protocol_version": "node.v1",
"task": {
"task_id": task_id.to_string(),
"lease_id": lease_id.to_string(),
"model": "deepseek-chat",
"deadline_unix_ms": 1234567890123i64,
"complete_grace_until_unix_ms": 1234567950123i64,
"payload": {
"request_id": Uuid::new_v4().to_string(),
"chat": {
"model": "deepseek-chat",
"messages": [{"role": "user", "content": "Hello"}]
}
}
},
"retry_after_ms": null
});
let resp: NodePollResponse = serde_json::from_value(json).unwrap();
assert_eq!(resp.protocol_version, "node.v1");
assert!(resp.task.is_some());
let task = resp.task.unwrap();
assert_eq!(task.task_id, task_id);
assert_eq!(task.lease_id, lease_id);
assert_eq!(task.model, "deepseek-chat");
}
#[test]
fn test_node_poll_response_no_task() {
let json = serde_json::json!({
"protocol_version": "node.v1",
"task": null,
"retry_after_ms": 1000
});
let resp: NodePollResponse = serde_json::from_value(json).unwrap();
assert!(resp.task.is_none());
assert_eq!(resp.retry_after_ms, Some(1000));
}
#[test]
fn test_node_task_complete_request_succeeded() {
let node_id = Uuid::new_v4();
let session_id = Uuid::new_v4();
let task_id = Uuid::new_v4();
let lease_id = Uuid::new_v4();
let req = NodeTaskCompleteRequest {
protocol_version: "node.v1".to_string(),
node_id,
session_id,
task_id,
lease_id,
result: NodeTaskResult::Succeeded {
response: ChatCompletionResponse {
id: "resp-001".to_string(),
object: "chat.completion".to_string(),
created: 1234567890,
model: "deepseek-chat".to_string(),
choices: vec![CompletionChoice {
index: 0,
message: MessageContent {
role: "assistant".to_string(),
content: "Hello! How can I help you?".to_string(),
},
finish_reason: Some("stop".to_string()),
}],
usage: Usage {
prompt_tokens: 10,
completion_tokens: 20,
total_tokens: 30,
},
},
},
};
let json = serde_json::to_string(&req).unwrap();
let parsed: NodeTaskCompleteRequest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.protocol_version, "node.v1");
assert_eq!(parsed.node_id, node_id);
assert_eq!(parsed.task_id, task_id);
assert_eq!(parsed.lease_id, lease_id);
match parsed.result {
NodeTaskResult::Succeeded { response } => {
assert_eq!(response.id, "resp-001");
assert_eq!(
response.choices[0].message.content,
"Hello! How can I help you?"
);
}
NodeTaskResult::Failed { .. } => panic!("Expected Succeeded variant"),
}
}
#[test]
fn test_node_task_complete_request_failed() {
let node_id = Uuid::new_v4();
let session_id = Uuid::new_v4();
let task_id = Uuid::new_v4();
let lease_id = Uuid::new_v4();
let req = NodeTaskCompleteRequest {
protocol_version: "node.v1".to_string(),
node_id,
session_id,
task_id,
lease_id,
result: NodeTaskResult::Failed {
code: "ollama_error".to_string(),
message: "Model not found".to_string(),
},
};
let json = serde_json::to_string(&req).unwrap();
let parsed: NodeTaskCompleteRequest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.protocol_version, "node.v1");
assert_eq!(parsed.task_id, task_id);
match parsed.result {
NodeTaskResult::Failed { code, message } => {
assert_eq!(code, "ollama_error");
assert_eq!(message, "Model not found");
}
NodeTaskResult::Succeeded { .. } => panic!("Expected Failed variant"),
}
}
#[test]
fn test_node_task_complete_response_deserialize() {
let json = serde_json::json!({
"action": "succeeded",
"task_status": "succeeded",
"node_status": "online",
"server_failure_count": 0,
"failure_threshold": 3
});
let resp: NodeTaskCompleteResponse = serde_json::from_value(json).unwrap();
assert_eq!(resp.action, NodeTaskCompleteAction::Succeeded);
assert_eq!(resp.task_status, "succeeded");
assert_eq!(resp.node_status, "online");
}
#[test]
fn test_node_task_complete_action_enum() {
let actions = vec![
(NodeTaskCompleteAction::Succeeded, "succeeded"),
(NodeTaskCompleteAction::Requeued, "requeued"),
(NodeTaskCompleteAction::Failed, "failed"),
(NodeTaskCompleteAction::Expired, "expired"),
];
for (action, expected_str) in actions {
let json = serde_json::to_string(&action).unwrap();
assert_eq!(json, format!("\"{}\"", expected_str));
let deserialized: NodeTaskCompleteAction = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized, action);
}
}
#[test]
fn test_chat_completion_response_serialize() {
let resp = ChatCompletionResponse {
id: "resp-001".to_string(),
object: "chat.completion".to_string(),
created: 1234567890,
model: "deepseek-chat".to_string(),
choices: vec![CompletionChoice {
index: 0,
message: MessageContent {
role: "assistant".to_string(),
content: "Test response".to_string(),
},
finish_reason: Some("stop".to_string()),
}],
usage: Usage {
prompt_tokens: 5,
completion_tokens: 10,
total_tokens: 15,
},
};
let json = serde_json::to_string(&resp).unwrap();
assert!(json.contains("\"id\":\"resp-001\""));
assert!(json.contains("\"object\":\"chat.completion\""));
assert!(json.contains("\"model\":\"deepseek-chat\""));
assert!(json.contains("\"role\":\"assistant\""));
assert!(json.contains("Test response"));
}
#[test]
fn test_chat_completion_response_deserialize() {
let json = r#"{
"id": "resp-002",
"object": "chat.completion",
"created": 1234567891,
"model": "llama3",
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": "Response"},
"finish_reason": "stop"
}],
"usage": {"prompt_tokens": 8, "completion_tokens": 12, "total_tokens": 20}
}"#;
let resp: ChatCompletionResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.id, "resp-002");
assert_eq!(resp.model, "llama3");
assert_eq!(resp.choices.len(), 1);
assert_eq!(resp.choices[0].message.content, "Response");
assert_eq!(resp.usage.total_tokens, 20);
}
#[test]
fn test_message_serialize() {
let msg = Message::user("Hello, world!");
let json = serde_json::to_string(&msg).unwrap();
let parsed: Message = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.role, MessageRole::User);
assert_eq!(parsed.content, "Hello, world!");
}
#[test]
fn test_message_deserialize() {
let json = r#"{"role":"system","content":"You are a helpful assistant"}"#;
let msg: Message = serde_json::from_str(json).unwrap();
assert_eq!(msg.role, MessageRole::System);
assert_eq!(msg.content, "You are a helpful assistant");
}
#[test]
fn test_message_helper_methods() {
let system_msg = Message::system("System prompt");
assert_eq!(system_msg.role, MessageRole::System);
assert_eq!(system_msg.content, "System prompt");
let user_msg = Message::user("User question");
assert_eq!(user_msg.role, MessageRole::User);
assert_eq!(user_msg.content, "User question");
let assistant_msg = Message::assistant("Assistant response");
assert_eq!(assistant_msg.role, MessageRole::Assistant);
assert_eq!(assistant_msg.content, "Assistant response");
let tool_msg = Message::tool("Tool output");
assert_eq!(tool_msg.role, MessageRole::Tool);
assert_eq!(tool_msg.content, "Tool output");
}
#[test]
fn test_message_with_unicode() {
let msg = Message::user("🚀 My Node 节点 Привет");
let json = serde_json::to_string(&msg).unwrap();
let parsed: Message = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.content, "🚀 My Node 节点 Привет");
}
#[test]
fn test_message_with_special_characters() {
let content = r#"Line1\nLine2\"Quote\\Backslash"#;
let msg = Message::user(content);
let json = serde_json::to_string(&msg).unwrap();
let parsed: Message = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.content, content);
}
#[test]
fn test_empty_accepted_models() {
let node_id = Uuid::new_v4();
let session_id = Uuid::new_v4();
let req = NodeHeartbeatRequest {
protocol_version: "node.v1".to_string(),
node_id,
session_id,
accepted_models: vec![], };
let json = serde_json::to_string(&req).unwrap();
let parsed: NodeHeartbeatRequest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.accepted_models.len(), 0);
}
#[test]
fn test_chat_completion_request_with_optional_fields() {
let req = ChatCompletionRequest {
model: "deepseek-chat".to_string(),
messages: vec![Message::user("Test")],
stream: Some(false),
max_tokens: Some(100),
temperature: Some(0.7),
top_p: Some(0.9),
n: Some(1),
stop: Some(vec!["\n".to_string()]),
};
let json = serde_json::to_string(&req).unwrap();
let parsed: ChatCompletionRequest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.max_tokens, Some(100));
assert_eq!(parsed.temperature, Some(0.7));
assert_eq!(parsed.top_p, Some(0.9));
assert_eq!(parsed.stop, Some(vec!["\n".to_string()]));
}
#[test]
fn test_node_task_complete_action_all_variants() {
let actions = vec![
NodeTaskCompleteAction::Succeeded,
NodeTaskCompleteAction::Requeued,
NodeTaskCompleteAction::Failed,
NodeTaskCompleteAction::Expired,
];
for action in actions {
let json = serde_json::to_string(&action).unwrap();
let parsed: NodeTaskCompleteAction = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, action);
}
}
}