use super::{AIMessage, AIProviderTrait, CompletionRequest, ModelConfig};
use crate::protocol::{Message, MessageType, Channel, ChannelType};
use crate::{Error, Result, Identity, Fingerprint};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::Mutex;
/// Individual features that can be granted to an [`AIAgent`] through
/// [`AgentConfig::capabilities`].
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub enum AgentCapability {
    /// Allows the agent to answer direct messages (checked in `should_respond`).
    DirectMessage,
    /// Allows the agent to answer channel messages that mention it by
    /// fingerprint or by name (checked in `should_respond`).
    Mentions,
    /// Not consulted by code in this file.
    // NOTE(review): presumably gates channel participation — confirm with callers.
    ChannelParticipation,
    /// Enables `AIAgent::handle_command`; without it every command is rejected.
    Commands,
    /// Not consulted by code in this file.
    // NOTE(review): presumably gates blockchain lookups — confirm with callers.
    BlockchainQuery,
    /// Not consulted by code in this file.
    // NOTE(review): presumably gates forum posting — confirm with callers.
    ForumPost,
}
/// Static configuration for a single [`AIAgent`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentConfig {
    /// Agent identifier; also stamped into `metadata.agent_id` of every reply.
    pub id: String,
    /// Display name: used in the default system prompt, in help text, and as
    /// an accepted mention target.
    pub name: String,
    /// Model settings forwarded with each completion request. Its
    /// `system_prompt`, when set, overrides the built-in prompt.
    pub model: ModelConfig,
    /// Features this agent is allowed to use.
    pub capabilities: HashSet<AgentCapability>,
    /// Channel names associated with this agent.
    // NOTE(review): not read by code in this file — presumably an auto-join list; confirm.
    pub channels: Vec<String>,
    /// Optional text prepended (followed by a single space) to every generated reply.
    pub response_prefix: Option<String>,
    /// Maximum number of stored history messages included in a prompt.
    pub max_context: usize,
}
impl Default for AgentConfig {
fn default() -> Self {
Self {
id: "default-agent".to_string(),
name: "QComm Assistant".to_string(),
model: ModelConfig::default(),
capabilities: [
AgentCapability::DirectMessage,
AgentCapability::Mentions,
]
.into_iter()
.collect(),
channels: vec!["#quantum".to_string()],
response_prefix: None,
max_context: 10,
}
}
}
/// Per-message information the caller supplies to `AIAgent::process_message`.
#[derive(Clone, Debug)]
pub struct AgentContext {
    /// Channel the message arrived in. Its `id` keys the channel history;
    /// when absent for a non-DM, the key `"unknown"` is used.
    pub channel: Option<Channel>,
    /// Recently seen messages.
    // NOTE(review): not read by code in this file — confirm intended use.
    pub recent_messages: Vec<Message>,
    /// Mention targets in the message; compared against the agent's
    /// fingerprint hex and its configured name.
    pub mentions: Vec<String>,
    /// `true` when the message is a direct message rather than a channel post.
    pub is_dm: bool,
}
/// A chat participant that answers messages using a pluggable AI provider.
pub struct AIAgent {
    /// Static configuration supplied at construction.
    config: AgentConfig,
    /// Cryptographic identity; its fingerprint hex is used as the sender of replies.
    identity: Identity,
    /// Completion backend; generating a reply fails while this is `None`.
    provider: Option<Arc<dyn AIProviderTrait>>,
    /// Names of channels the agent has joined.
    joined_channels: Arc<Mutex<HashSet<String>>>,
    /// Stored prompt history, keyed by DM sender or channel id.
    history: Arc<Mutex<HashMap<String, Vec<AIMessage>>>>,
}
impl AIAgent {
pub fn new(config: AgentConfig) -> Result<Self> {
let identity = Identity::generate()?;
Ok(Self {
config,
identity,
provider: None,
joined_channels: Arc::new(Mutex::new(HashSet::new())),
history: Arc::new(Mutex::new(HashMap::new())),
})
}
pub fn with_identity(config: AgentConfig, identity: Identity) -> Self {
Self {
config,
identity,
provider: None,
joined_channels: Arc::new(Mutex::new(HashSet::new())),
history: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn set_provider(&mut self, provider: Arc<dyn AIProviderTrait>) {
self.provider = Some(provider);
}
pub fn id(&self) -> &str {
&self.config.id
}
pub fn name(&self) -> &str {
&self.config.name
}
pub fn fingerprint(&self) -> &Fingerprint {
self.identity.fingerprint()
}
pub fn has_capability(&self, cap: AgentCapability) -> bool {
self.config.capabilities.contains(&cap)
}
pub async fn join_channel(&self, channel: &str) {
self.joined_channels.lock().await.insert(channel.to_string());
tracing::info!("Agent {} joined channel {}", self.config.name, channel);
}
pub async fn leave_channel(&self, channel: &str) {
self.joined_channels.lock().await.remove(channel);
tracing::info!("Agent {} left channel {}", self.config.name, channel);
}
pub async fn is_in_channel(&self, channel: &str) -> bool {
self.joined_channels.lock().await.contains(channel)
}
pub async fn process_message(&self, msg: &Message, context: AgentContext) -> Result<Option<Message>> {
let should_respond = self.should_respond(msg, &context);
if !should_respond {
return Ok(None);
}
let messages = self.build_context(msg, &context).await;
let response_text = self.generate_response(messages).await?;
let mut response = Message::text(
&self.identity.fingerprint().to_hex(),
&msg.sender,
response_text,
);
response.metadata.agent_id = Some(self.config.id.clone());
response.reply_to = Some(msg.id.clone());
Ok(Some(response))
}
fn should_respond(&self, msg: &Message, context: &AgentContext) -> bool {
if msg.sender == self.identity.fingerprint().to_hex() {
return false;
}
if context.is_dm && !self.has_capability(AgentCapability::DirectMessage) {
return false;
}
if !context.is_dm {
if !self.has_capability(AgentCapability::Mentions) {
return false;
}
let our_fp = self.identity.fingerprint().to_hex();
if !context.mentions.contains(&our_fp) && !context.mentions.contains(&self.config.name) {
return false;
}
}
true
}
async fn build_context(&self, msg: &Message, context: &AgentContext) -> Vec<AIMessage> {
let mut messages = Vec::new();
let system_prompt = self.config.model.system_prompt.clone().unwrap_or_else(|| {
format!(
"You are {}, an AI assistant in the Quantum Communicator chat app. \
You help users with questions about the app, blockchain, and general topics. \
Keep responses concise and helpful.",
self.config.name
)
});
messages.push(AIMessage::system(system_prompt));
let channel_id = if context.is_dm {
&msg.sender
} else {
context.channel.as_ref().map(|c| c.id.as_str()).unwrap_or("unknown")
};
let history = self.history.lock().await;
if let Some(channel_history) = history.get(channel_id) {
let start = channel_history.len().saturating_sub(self.config.max_context);
messages.extend(channel_history[start..].to_vec());
}
if let MessageType::Text = msg.msg_type {
if let crate::protocol::message::MessageContent::Text(text) = &msg.content {
messages.push(AIMessage::user(format!("{}: {}", msg.sender, text)));
}
}
messages
}
async fn generate_response(&self, messages: Vec<AIMessage>) -> Result<String> {
let provider = self.provider.as_ref().ok_or_else(|| {
Error::AiAgent("No AI provider configured".to_string())
})?;
let request = CompletionRequest {
messages,
config: self.config.model.clone(),
};
let response = provider.complete(request).await?;
let text = if let Some(prefix) = &self.config.response_prefix {
format!("{} {}", prefix, response.content)
} else {
response.content
};
Ok(text)
}
pub async fn handle_command(&self, command: &str, args: &[&str]) -> Result<String> {
if !self.has_capability(AgentCapability::Commands) {
return Err(Error::AiAgent("Commands not enabled".to_string()));
}
match command {
"help" => Ok(self.help_text()),
"status" => Ok(self.status_text().await),
"channels" => Ok(self.channels_text().await),
_ => Ok(format!("Unknown command: {}", command)),
}
}
fn help_text(&self) -> String {
format!(
"**{}** - AI Assistant\n\n\
Available commands:\n\
- `/help` - Show this help\n\
- `/status` - Show agent status\n\
- `/channels` - List joined channels\n\n\
Mention me in a channel or send a DM to chat.",
self.config.name
)
}
async fn status_text(&self) -> String {
let channels = self.joined_channels.lock().await;
format!(
"Agent: {}\n\
ID: {}\n\
Fingerprint: {}\n\
Channels: {}\n\
Capabilities: {:?}",
self.config.name,
self.config.id,
self.identity.fingerprint(),
channels.len(),
self.config.capabilities
)
}
async fn channels_text(&self) -> String {
let channels = self.joined_channels.lock().await;
if channels.is_empty() {
"Not in any channels".to_string()
} else {
channels.iter().cloned().collect::<Vec<_>>().join(", ")
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default agent with the `Commands` capability additionally enabled.
    fn agent_with_commands() -> AIAgent {
        let mut config = AgentConfig::default();
        config.capabilities.insert(AgentCapability::Commands);
        AIAgent::new(config).unwrap()
    }

    #[test]
    fn test_agent_creation() {
        let config = AgentConfig::default();
        let agent = AIAgent::new(config).unwrap();
        assert!(agent.has_capability(AgentCapability::DirectMessage));
        assert!(agent.has_capability(AgentCapability::Mentions));
        assert!(!agent.has_capability(AgentCapability::Commands));
    }

    #[tokio::test]
    async fn test_join_channel() {
        let agent = AIAgent::new(AgentConfig::default()).unwrap();
        agent.join_channel("#test").await;
        assert!(agent.is_in_channel("#test").await);
        agent.leave_channel("#test").await;
        assert!(!agent.is_in_channel("#test").await);
    }

    #[tokio::test]
    async fn test_commands_require_capability() {
        let agent = AIAgent::new(AgentConfig::default()).unwrap();
        assert!(agent.handle_command("help", &[]).await.is_err());
    }

    #[tokio::test]
    async fn test_channels_command() {
        let agent = agent_with_commands();
        assert_eq!(
            agent.handle_command("channels", &[]).await.unwrap(),
            "Not in any channels"
        );
        agent.join_channel("#a").await;
        assert_eq!(agent.handle_command("channels", &[]).await.unwrap(), "#a");
    }

    #[tokio::test]
    async fn test_help_and_unknown_commands() {
        let agent = agent_with_commands();
        let help = agent.handle_command("help", &[]).await.unwrap();
        assert!(help.contains("QComm Assistant"));
        assert_eq!(
            agent.handle_command("frobnicate", &[]).await.unwrap(),
            "Unknown command: frobnicate"
        );
    }
}