use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::Result;
/// Core contract implemented by every agent.
///
/// The `Send + Sync` supertraits allow agents to be shared across threads.
#[async_trait]
pub trait AgentProtocol: Send + Sync {
    /// Human-readable name of this agent.
    fn name(&self) -> &str;

    /// Sends `prompt` to the agent synchronously and returns its reply.
    fn chat(&self, prompt: &str) -> Result<String>;

    /// Asynchronous counterpart of [`AgentProtocol::chat`].
    async fn achat(&self, prompt: &str) -> Result<String>;
}
#[async_trait]
pub trait RunnableAgentProtocol: AgentProtocol {
fn run(&self, prompt: &str) -> Result<String> {
self.chat(prompt)
}
fn start(&self, prompt: &str) -> Result<String> {
self.chat(prompt)
}
async fn arun(&self, prompt: &str) -> Result<String> {
self.achat(prompt).await
}
async fn astart(&self, prompt: &str) -> Result<String> {
self.achat(prompt).await
}
}
/// A tool that an agent can expose to an LLM and execute on request.
#[async_trait]
pub trait ToolProtocol: Send + Sync {
    /// Identifier of the tool.
    fn name(&self) -> &str;

    /// Human-readable description of what the tool does.
    fn description(&self) -> &str;

    /// JSON description of the accepted arguments
    /// (presumably a JSON Schema object — confirm with implementors).
    fn parameters_schema(&self) -> serde_json::Value;

    /// Executes the tool with `args` and returns its JSON result.
    async fn execute(&self, args: serde_json::Value) -> Result<serde_json::Value>;
}
/// Storage backend for conversation history.
#[async_trait]
pub trait MemoryProtocol: Send + Sync {
    /// Appends a message with the given `role` and `content`.
    async fn store(&mut self, role: &str, content: &str) -> Result<()>;

    /// Returns the stored messages.
    async fn history(&self) -> Result<Vec<MemoryMessage>>;

    /// Removes stored messages.
    async fn clear(&mut self) -> Result<()>;

    /// Looks up messages matching `query`
    /// (presumably returning at most `limit` results — confirm with implementors).
    async fn search(&self, query: &str, limit: usize) -> Result<Vec<MemoryMessage>>;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryMessage {
pub role: String,
pub content: String,
pub timestamp: u64,
pub metadata: HashMap<String, String>,
}
impl MemoryMessage {
    /// Creates a message stamped with the current time and empty metadata.
    ///
    /// The timestamp is whole seconds since the Unix epoch. If the system clock
    /// reports a time before the epoch, it falls back to `0`.
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        let role = role.into();
        let content = content.into();
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs())
            .unwrap_or(0);

        Self {
            role,
            content,
            timestamp,
            metadata: HashMap::new(),
        }
    }
}
/// A chat-completion backend.
#[async_trait]
pub trait LlmProtocol: Send + Sync {
    /// Name of the model this backend talks to.
    fn model(&self) -> &str;

    /// Sends the conversation in `messages` and returns the model's response.
    async fn chat(&self, messages: &[LlmMessage]) -> Result<LlmResponse>;

    /// Like [`LlmProtocol::chat`], but also advertises `tools` the model may
    /// request via [`LlmResponse::tool_calls`].
    async fn chat_with_tools(
        &self,
        messages: &[LlmMessage],
        tools: &[ToolSchema],
    ) -> Result<LlmResponse>;
}
/// One turn of a conversation sent to an [`LlmProtocol`] backend.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LlmMessage {
    /// Speaker role; the constructors use `"system"`, `"user"` or `"assistant"`.
    pub role: String,
    /// Message text.
    pub content: String,
}
impl LlmMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: "system".to_string(),
content: content.into(),
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: "user".to_string(),
content: content.into(),
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: "assistant".to_string(),
content: content.into(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
pub content: String,
pub tool_calls: Vec<ToolCall>,
pub usage: Option<TokenUsage>,
}
/// A tool invocation requested by the model inside an [`LlmResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCall {
    /// Identifier of this call (presumably used to correlate the tool result — confirm).
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the tool as arbitrary JSON.
    pub arguments: serde_json::Value,
}
/// Token counts reported for a single LLM call.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion.
    pub completion_tokens: u32,
    /// Total tokens (presumably prompt + completion; not enforced here).
    pub total_tokens: u32,
}
/// Description of a tool advertised to an LLM via `chat_with_tools`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolSchema {
    /// Tool identifier.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// JSON description of accepted parameters (presumably JSON Schema — confirm).
    pub parameters: serde_json::Value,
}
/// Connection settings for an AgentOS service.
///
/// `Debug` is implemented by hand so the `api_key` secret is never written to
/// logs or panic messages. Only its presence is shown.
#[derive(Clone, Serialize, Deserialize)]
pub struct AgentOSConfig {
    /// Base URL of the AgentOS service.
    pub url: String,
    /// Optional API key; redacted in `Debug` output.
    pub api_key: Option<String>,
    /// Optional agent identifier.
    pub agent_id: Option<String>,
    /// Whether telemetry is enabled.
    pub telemetry: bool,
    /// Request timeout (units not stated here — presumably seconds, confirm).
    pub timeout: u32,
}

impl std::fmt::Debug for AgentOSConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AgentOSConfig")
            .field("url", &self.url)
            // Never print the secret itself, only whether one is configured.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("agent_id", &self.agent_id)
            .field("telemetry", &self.telemetry)
            .field("timeout", &self.timeout)
            .finish()
    }
}
impl Default for AgentOSConfig {
fn default() -> Self {
Self {
url: "https://agentos.praison.ai".to_string(),
api_key: None,
agent_id: None,
telemetry: true,
timeout: 30,
}
}
}
impl AgentOSConfig {
    /// Creates a config for `url`; every other field comes from `Default`.
    #[must_use]
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            ..Default::default()
        }
    }

    /// Sets the API key (builder style).
    #[must_use]
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }

    /// Sets the agent identifier (builder style).
    #[must_use]
    pub fn agent_id(mut self, id: impl Into<String>) -> Self {
        self.agent_id = Some(id.into());
        self
    }

    /// Disables telemetry (builder style).
    #[must_use]
    pub fn no_telemetry(mut self) -> Self {
        self.telemetry = false;
        self
    }

    /// Overrides the request timeout (builder style).
    ///
    /// Uses the same units as the `timeout` field (presumably seconds — confirm).
    #[must_use]
    pub fn timeout(mut self, timeout: u32) -> Self {
        self.timeout = timeout;
        self
    }
}
/// Client for an AgentOS control plane configured by [`AgentOSConfig`].
#[async_trait]
pub trait AgentOSProtocol: Send + Sync {
    /// Configuration this client was built with.
    fn config(&self) -> &AgentOSConfig;

    /// Registers the agent (the returned string is presumably an id — confirm).
    async fn register(&self) -> Result<String>;

    /// Sends a liveness heartbeat.
    async fn heartbeat(&self) -> Result<()>;

    /// Reports a metrics snapshot.
    async fn report_metrics(&self, metrics: AgentMetrics) -> Result<()>;

    /// Fetches remote configuration as raw JSON.
    async fn get_config(&self) -> Result<serde_json::Value>;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMetrics {
pub requests: u64,
pub errors: u64,
pub avg_response_time_ms: f64,
pub tokens_used: u64,
pub custom: HashMap<String, f64>,
}
impl Default for AgentMetrics {
fn default() -> Self {
Self {
requests: 0,
errors: 0,
avg_response_time_ms: 0.0,
tokens_used: 0,
custom: HashMap::new(),
}
}
}
impl AgentMetrics {
    /// Equivalent to `AgentMetrics::default()`.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Builder-style setter that records `value` under `key` in `custom`.
    /// An existing value for the same key is replaced.
    pub fn custom(self, key: impl Into<String>, value: f64) -> Self {
        let mut metrics = self;
        metrics.custom.insert(key.into(), value);
        metrics
    }
}
/// A chat bot that reacts to incoming messages and commands.
#[async_trait]
pub trait BotProtocol: Send + Sync {
    /// Bot name.
    fn name(&self) -> &str;

    /// Handles an incoming message and produces a reply.
    async fn on_message(&self, message: BotMessage) -> Result<BotResponse>;

    /// Handles `command` with its already-split `args`.
    async fn on_command(&self, command: &str, args: &[&str]) -> Result<BotResponse>;

    /// Starts the bot.
    async fn start(&mut self) -> Result<()>;

    /// Stops the bot.
    async fn stop(&mut self) -> Result<()>;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BotMessage {
pub id: String,
pub sender_id: String,
pub sender_name: Option<String>,
pub content: String,
pub channel_id: Option<String>,
pub timestamp: u64,
pub metadata: HashMap<String, String>,
}
impl BotMessage {
pub fn new(sender_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
sender_id: sender_id.into(),
sender_name: None,
content: content.into(),
channel_id: None,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
metadata: HashMap::new(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BotResponse {
pub content: String,
pub reply_to: Option<String>,
pub attachments: Vec<BotAttachment>,
pub actions: Vec<BotAction>,
}
impl BotResponse {
pub fn text(content: impl Into<String>) -> Self {
Self {
content: content.into(),
reply_to: None,
attachments: Vec::new(),
actions: Vec::new(),
}
}
pub fn reply_to(mut self, message_id: impl Into<String>) -> Self {
self.reply_to = Some(message_id.into());
self
}
pub fn attachment(mut self, attachment: BotAttachment) -> Self {
self.attachments.push(attachment);
self
}
pub fn action(mut self, action: BotAction) -> Self {
self.actions.push(action);
self
}
}
/// A file or media item attached to a [`BotResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BotAttachment {
    /// Kind of attachment (free-form string; the valid values are not defined here).
    pub attachment_type: String,
    /// Location of the attachment.
    pub url: String,
    /// Optional caption/title.
    pub title: Option<String>,
}
/// An interactive element (e.g. a button) offered with a [`BotResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BotAction {
    /// Identifier of the action.
    pub id: String,
    /// Text shown to the user.
    pub label: String,
    /// Kind of action as a free-form string (the tests use `"button"`).
    pub action_type: String,
    /// Optional payload associated with the action.
    pub value: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal agent whose replies echo the prompt, used to exercise the traits.
    struct MockAgent {
        name: String,
    }

    #[async_trait]
    impl AgentProtocol for MockAgent {
        fn name(&self) -> &str {
            &self.name
        }

        fn chat(&self, prompt: &str) -> Result<String> {
            Ok(format!("Response to: {}", prompt))
        }

        async fn achat(&self, prompt: &str) -> Result<String> {
            Ok(format!("Async response to: {}", prompt))
        }
    }

    // Opt into the default `run`/`start`/`arun`/`astart` bodies so they are covered.
    #[async_trait]
    impl RunnableAgentProtocol for MockAgent {}

    fn mock_agent() -> MockAgent {
        MockAgent {
            name: "TestAgent".to_string(),
        }
    }

    #[test]
    fn test_mock_agent_protocol() {
        let agent = mock_agent();
        assert_eq!(agent.name(), "TestAgent");
        let response = agent.chat("Hello").unwrap();
        assert!(response.contains("Hello"));
    }

    #[tokio::test]
    async fn test_mock_agent_async() {
        let agent = mock_agent();
        let response = agent.achat("Hello").await.unwrap();
        assert!(response.contains("Async"));
    }

    #[test]
    fn test_runnable_defaults_delegate_to_chat() {
        let agent = mock_agent();
        assert_eq!(agent.run("Hi").unwrap(), "Response to: Hi");
        assert_eq!(agent.start("Hi").unwrap(), "Response to: Hi");
    }

    #[tokio::test]
    async fn test_runnable_async_defaults_delegate_to_achat() {
        let agent = mock_agent();
        assert_eq!(agent.arun("Hi").await.unwrap(), "Async response to: Hi");
        assert_eq!(agent.astart("Hi").await.unwrap(), "Async response to: Hi");
    }

    #[test]
    fn test_memory_message() {
        let msg = MemoryMessage::new("user", "Hello");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content, "Hello");
        assert!(msg.timestamp > 0);
        assert!(msg.metadata.is_empty());
    }

    #[test]
    fn test_llm_message() {
        let system = LlmMessage::system("You are helpful");
        let user = LlmMessage::user("Hello");
        let assistant = LlmMessage::assistant("Hi there");
        assert_eq!(system.role, "system");
        assert_eq!(user.role, "user");
        assert_eq!(assistant.role, "assistant");
        assert_eq!(assistant.content, "Hi there");
    }

    #[test]
    fn test_agent_os_config() {
        let config = AgentOSConfig::new("https://custom.url")
            .api_key("test-key")
            .agent_id("agent-123")
            .no_telemetry();
        assert_eq!(config.url, "https://custom.url");
        assert_eq!(config.api_key, Some("test-key".to_string()));
        assert_eq!(config.agent_id, Some("agent-123".to_string()));
        assert!(!config.telemetry);
    }

    #[test]
    fn test_agent_os_config_default() {
        let config = AgentOSConfig::default();
        assert_eq!(config.url, "https://agentos.praison.ai");
        assert!(config.api_key.is_none());
        assert!(config.agent_id.is_none());
        assert!(config.telemetry);
        assert_eq!(config.timeout, 30);
    }

    #[test]
    fn test_agent_metrics() {
        let metrics = AgentMetrics::new()
            .custom("latency_p99", 150.0)
            .custom("cache_hit_rate", 0.85);
        assert_eq!(metrics.custom.len(), 2);
        assert_eq!(metrics.custom.get("latency_p99"), Some(&150.0));
        assert_eq!(metrics.requests, 0);
        assert_eq!(metrics.errors, 0);
    }

    #[test]
    fn test_bot_message() {
        let msg = BotMessage::new("user123", "Hello bot");
        assert_eq!(msg.sender_id, "user123");
        assert_eq!(msg.content, "Hello bot");
        assert!(!msg.id.is_empty());
        assert!(msg.sender_name.is_none());
        assert!(msg.channel_id.is_none());
    }

    #[test]
    fn test_bot_response() {
        let response = BotResponse::text("Hello!")
            .reply_to("msg-123")
            .action(BotAction {
                id: "btn1".to_string(),
                label: "Click me".to_string(),
                action_type: "button".to_string(),
                value: None,
            });
        assert_eq!(response.content, "Hello!");
        assert_eq!(response.reply_to, Some("msg-123".to_string()));
        assert_eq!(response.actions.len(), 1);
        assert!(response.attachments.is_empty());
    }

    #[test]
    fn test_bot_response_attachment() {
        let response = BotResponse::text("See file").attachment(BotAttachment {
            attachment_type: "image".to_string(),
            url: "https://example.com/a.png".to_string(),
            title: Some("A".to_string()),
        });
        assert_eq!(response.attachments.len(), 1);
        assert_eq!(response.attachments[0].url, "https://example.com/a.png");
        assert!(response.reply_to.is_none());
    }

    #[test]
    fn test_tool_schema() {
        let schema = ToolSchema {
            name: "search".to_string(),
            description: "Search the web".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                }
            }),
        };
        assert_eq!(schema.name, "search");
        assert_eq!(schema.parameters["type"], "object");
    }

    #[test]
    fn test_llm_response() {
        let response = LlmResponse {
            content: "Hello!".to_string(),
            tool_calls: vec![ToolCall {
                id: "call-1".to_string(),
                name: "search".to_string(),
                arguments: serde_json::json!({"query": "test"}),
            }],
            usage: Some(TokenUsage {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
            }),
        };
        assert_eq!(response.tool_calls.len(), 1);
        assert_eq!(response.usage.unwrap().total_tokens, 15);
    }

    #[test]
    fn test_llm_response_serde_round_trip() {
        let response = LlmResponse {
            content: "Hi".to_string(),
            tool_calls: Vec::new(),
            usage: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        let back: LlmResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(back.content, "Hi");
        assert!(back.tool_calls.is_empty());
        assert!(back.usage.is_none());
    }
}