use std::collections::HashMap;
use std::sync::Arc;
/// Settings that control how a [`ContextAgent`] selects the context it returns.
///
/// The derived `Default` yields no entry limit, the default
/// [`ContextStrategy`], and both `include_*` flags set to `false`.
#[derive(Debug, Clone, Default)]
pub struct ContextAgentConfig {
    /// Optional cap consulted by [`ContextAgent::get_context`] under
    /// [`ContextStrategy::Recent`]. Despite the name, that method applies it
    /// as a number of *entries*, not tokens.
    pub max_tokens: Option<usize>,
    /// Strategy used to pick entries when the context is read.
    pub strategy: ContextStrategy,
    /// Not read by any code in this file; presumably toggles inclusion of
    /// system messages — confirm against callers.
    pub include_system: bool,
    /// Not read by any code in this file; presumably toggles inclusion of
    /// tool messages — confirm against callers.
    pub include_tools: bool,
}
/// Strategy a [`ContextAgent`] uses to choose which entries to hand back.
///
/// Within this file only [`ContextStrategy::Recent`] receives dedicated
/// handling in [`ContextAgent::get_context`]; every other variant falls back
/// to returning the whole stored history.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ContextStrategy {
    /// Limit the result to the most recently added entries. This is the
    /// default variant.
    #[default]
    Recent,
    /// No summarisation logic is visible in this file — presumably handled
    /// elsewhere; confirm before relying on it.
    Summarize,
    /// No semantic-selection logic is visible in this file — presumably
    /// handled elsewhere; confirm before relying on it.
    Semantic,
    /// No custom-selection hook is visible in this file — presumably handled
    /// elsewhere; confirm before relying on it.
    Custom,
}
/// Holds a bounded, ordered history of [`ContextEntry`] values together with
/// the configuration used when that history is read back.
#[derive(Debug, Clone)]
pub struct ContextAgent {
    /// Configuration consulted when the context is read.
    pub config: ContextAgentConfig,
    /// Stored entries, oldest first (new entries are pushed at the end and
    /// evictions remove index 0).
    context: Vec<ContextEntry>,
    /// Maximum number of entries kept before the oldest one is evicted.
    max_entries: usize,
}
/// One item of stored context.
#[derive(Debug, Clone)]
pub struct ContextEntry {
    /// Free-form role label supplied by the caller (the tests in this file
    /// use `"user"` and `"assistant"`).
    pub role: String,
    /// Text of the entry.
    pub content: String,
    /// Wall-clock time recorded for the entry; `ContextAgent::add` sets it to
    /// the time of insertion.
    pub timestamp: std::time::SystemTime,
    /// Arbitrary JSON metadata keyed by name; `ContextAgent::add` initialises
    /// it to an empty map.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// A default agent is simply an agent built from the default configuration.
impl Default for ContextAgent {
    fn default() -> Self {
        let config = ContextAgentConfig::default();
        Self::new(config)
    }
}
impl ContextAgent {
    /// Creates an agent with the given configuration and an empty history.
    ///
    /// The history is capped at 100 entries; this constructor offers no way
    /// to change that cap.
    pub fn new(config: ContextAgentConfig) -> Self {
        Self {
            config,
            context: Vec::new(),
            max_entries: 100,
        }
    }

    /// Appends an entry to the history.
    ///
    /// The entry is stamped with the current system time and starts with
    /// empty metadata. When the history grows past its cap the oldest entry
    /// is dropped.
    pub fn add(&mut self, role: impl Into<String>, content: impl Into<String>) {
        self.context.push(ContextEntry {
            role: role.into(),
            content: content.into(),
            timestamp: std::time::SystemTime::now(),
            metadata: HashMap::new(),
        });
        // Exactly one entry is added per call, so one eviction is enough to
        // restore the cap.
        if self.context.len() > self.max_entries {
            self.context.remove(0);
        }
    }

    /// Returns the entries selected by the configured strategy, always in
    /// chronological (oldest-first) order.
    ///
    /// * [`ContextStrategy::Recent`] — the last `config.max_tokens` entries,
    ///   or every retained entry when that option is `None`. Note that the
    ///   limit counts *entries*, not tokens.
    /// * Every other strategy — every retained entry.
    pub fn get_context(&self) -> Vec<&ContextEntry> {
        match self.config.strategy {
            ContextStrategy::Recent => {
                let limit = self.config.max_tokens.unwrap_or(self.max_entries);
                // Take the tail as a slice instead of `rev().take(limit)`:
                // the reversed form handed entries back newest-first, which
                // disagreed with the ordering used by every other strategy.
                let start = self.context.len().saturating_sub(limit);
                self.context[start..].iter().collect()
            }
            ContextStrategy::Summarize | ContextStrategy::Semantic | ContextStrategy::Custom => {
                self.context.iter().collect()
            }
        }
    }

    /// Removes every stored entry.
    pub fn clear(&mut self) {
        self.context.clear();
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.context.len()
    }

    /// Returns `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.context.is_empty()
    }
}
pub fn create_context_agent() -> ContextAgent {
ContextAgent::default()
}
/// Convenience constructor that forwards `config` to [`ContextAgent::new`].
pub fn create_context_agent_with_config(config: ContextAgentConfig) -> ContextAgent {
    ContextAgent::new(config)
}
/// Configuration carried by a [`PlanningAgent`].
///
/// None of these fields are read by the planning logic in this file. Note
/// that the derived `Default` (0 / `false` / `None`) differs from the values
/// `PlanningAgent::default()` uses (10 / `true` / `None`).
#[derive(Debug, Clone, Default)]
pub struct PlanningAgentConfig {
    /// Presumably an upper bound on the number of plan steps — not enforced
    /// anywhere in this file; confirm against callers.
    pub max_steps: usize,
    /// Presumably enables a reasoning mode — not read in this file; confirm
    /// against callers.
    pub reasoning: bool,
    /// Presumably an identifier of the model to use — not read in this file;
    /// confirm against callers.
    pub llm: Option<String>,
}
/// A single step of a plan.
#[derive(Debug, Clone)]
pub struct PlanningStep {
    /// Step number; `PlanningAgent` assigns these starting at 1.
    pub step: usize,
    /// Human-readable description of the work.
    pub description: String,
    /// Current lifecycle state of the step.
    pub status: PlanningStepStatus,
    /// Outcome text: `PlanningAgent` stores the completion result or the
    /// failure message here; `None` until one of those is recorded.
    pub result: Option<String>,
}
/// Lifecycle state of a [`PlanningStep`].
///
/// Within this file `PlanningAgent` only ever assigns `Pending`, `Completed`
/// and `Failed`; the remaining variants are available to other code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PlanningStepStatus {
    /// Not started yet. This is the default variant.
    #[default]
    Pending,
    /// Work on the step has begun.
    InProgress,
    /// The step finished successfully.
    Completed,
    /// The step ended with an error.
    Failed,
    /// The step was deliberately not executed.
    Skipped,
}
/// Tracks a linear plan: an ordered list of [`PlanningStep`]s plus a cursor
/// pointing at the step currently being worked on.
#[derive(Debug, Clone)]
pub struct PlanningAgent {
    /// Configuration supplied at construction. None of its fields are read by
    /// the methods defined below.
    pub config: PlanningAgentConfig,
    /// Steps in execution order.
    plan: Vec<PlanningStep>,
    /// Zero-based index of the current step. It never exceeds `plan.len()`
    /// and equals it once every step has been completed.
    current_step: usize,
}

/// The default agent uses `max_steps = 10`, `reasoning = true` and no LLM,
/// which differs from `PlanningAgentConfig::default()`.
impl Default for PlanningAgent {
    fn default() -> Self {
        Self::new(PlanningAgentConfig {
            max_steps: 10,
            reasoning: true,
            llm: None,
        })
    }
}

impl PlanningAgent {
    /// Creates an agent with an empty plan and the cursor at the start.
    pub fn new(config: PlanningAgentConfig) -> Self {
        Self {
            config,
            plan: Vec::new(),
            current_step: 0,
        }
    }

    /// Replaces any existing plan with a single pending step describing
    /// `task`, resets the cursor, and returns a copy of the new plan.
    pub fn create_plan(&mut self, task: &str) -> Vec<PlanningStep> {
        self.plan = vec![PlanningStep {
            step: 1,
            description: format!("Execute task: {}", task),
            status: PlanningStepStatus::Pending,
            result: None,
        }];
        self.current_step = 0;
        self.plan.clone()
    }

    /// Appends a pending step numbered one past the current plan length.
    pub fn add_step(&mut self, description: impl Into<String>) {
        let step_num = self.plan.len() + 1;
        self.plan.push(PlanningStep {
            step: step_num,
            description: description.into(),
            status: PlanningStepStatus::Pending,
            result: None,
        });
    }

    /// Returns the step under the cursor, or `None` when the plan is empty
    /// or already complete.
    pub fn current(&self) -> Option<&PlanningStep> {
        self.plan.get(self.current_step)
    }

    /// Marks the current step as completed with `result` and advances the
    /// cursor.
    ///
    /// Does nothing when there is no current step. Advancing regardless used
    /// to push the cursor past the end of the plan, which made `progress()`
    /// exceed 100% and caused steps added afterwards to be skipped.
    pub fn complete_step(&mut self, result: Option<String>) {
        if let Some(step) = self.plan.get_mut(self.current_step) {
            step.status = PlanningStepStatus::Completed;
            step.result = result;
            self.current_step += 1;
        }
    }

    /// Marks the current step as failed and stores `error` as its result.
    ///
    /// The cursor is left where it is, so the failed step remains the
    /// current one. Does nothing when there is no current step.
    pub fn fail_step(&mut self, error: impl Into<String>) {
        if let Some(step) = self.plan.get_mut(self.current_step) {
            step.status = PlanningStepStatus::Failed;
            step.result = Some(error.into());
        }
    }

    /// All steps of the plan in order.
    pub fn steps(&self) -> &[PlanningStep] {
        &self.plan
    }

    /// Returns `true` once the cursor has moved past the last step. An empty
    /// plan counts as complete.
    pub fn is_complete(&self) -> bool {
        self.current_step >= self.plan.len()
    }

    /// Percentage (0.0 to 100.0) of steps the cursor has moved past. An
    /// empty plan reports 100.0.
    pub fn progress(&self) -> f32 {
        if self.plan.is_empty() {
            return 100.0;
        }
        (self.current_step as f32 / self.plan.len() as f32) * 100.0
    }
}
/// Size-bounded buffer of text entries that evicts the oldest entries to
/// make room for new ones.
///
/// Sizes are measured in bytes (`String::len`); the newline separators added
/// by [`FastContext::as_string`] are not counted. The derived `Default` has a
/// budget of 0, so it only ever holds the most recently added entry.
#[derive(Debug, Clone, Default)]
pub struct FastContext {
    /// Stored entries, oldest first.
    entries: Vec<String>,
    /// Byte budget for the sum of all entry lengths.
    max_size: usize,
    /// Sum of the byte lengths of `entries`.
    current_size: usize,
}

impl FastContext {
    /// Creates an empty buffer with a budget of `max_size` bytes.
    pub fn new(max_size: usize) -> Self {
        Self {
            entries: Vec::new(),
            max_size,
            current_size: 0,
        }
    }

    /// Appends `content`, first evicting the oldest entries until it fits.
    ///
    /// An entry larger than the whole budget is still stored: every older
    /// entry is evicted and the buffer then holds just that one entry, with
    /// [`FastContext::size`] exceeding the budget.
    pub fn add(&mut self, content: impl Into<String>) {
        let content = content.into();
        let incoming = content.len();

        // Count how many of the oldest entries must go, then remove them in
        // a single `drain`. Calling `remove(0)` once per evicted entry
        // shifts the whole vector each time and is quadratic.
        let mut retained = self.current_size;
        let mut evict = 0;
        for entry in &self.entries {
            if retained + incoming <= self.max_size {
                break;
            }
            retained -= entry.len();
            evict += 1;
        }
        self.entries.drain(..evict);

        self.entries.push(content);
        self.current_size = retained + incoming;
    }

    /// The stored entries, oldest first.
    pub fn get(&self) -> &[String] {
        &self.entries
    }

    /// All entries joined with `"\n"`.
    pub fn as_string(&self) -> String {
        self.entries.join("\n")
    }

    /// Removes every entry and resets the tracked size to zero.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.current_size = 0;
    }

    /// Total byte length of the stored entries.
    pub fn size(&self) -> usize {
        self.current_size
    }

    /// Returns `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
/// Configuration carried by an [`AutoAgents`] instance.
///
/// None of these fields are read by the code in this file; the descriptions
/// below are inferred from the names and should be confirmed against callers.
#[derive(Debug, Clone, Default)]
pub struct AutoAgentsConfig {
    /// Presumably the task the generated agents should address. Note that
    /// `AutoAgents::generate` takes its task as an argument and does not
    /// consult this field.
    pub task: String,
    /// Presumably the desired number of agents.
    pub num_agents: Option<usize>,
    /// Presumably an identifier of the model to use.
    pub llm: Option<String>,
    /// Presumably enables verbose output.
    pub verbose: bool,
}
/// Holds a configuration together with the agent specifications most
/// recently produced for it.
#[derive(Debug, Clone)]
pub struct AutoAgents {
    /// Configuration supplied at construction.
    pub config: AutoAgentsConfig,
    /// Specifications produced by the latest generation; empty until one has
    /// run.
    agents: Vec<AutoAgentSpec>,
}
/// Description of one agent produced by [`AutoAgents`].
#[derive(Debug, Clone)]
pub struct AutoAgentSpec {
    /// Display name of the agent.
    pub name: String,
    /// Role label of the agent.
    pub role: String,
    /// Goal text; `AutoAgents::generate` fills it with the task it was given.
    pub goal: String,
    /// Optional background text; `AutoAgents::generate` leaves it `None`.
    pub backstory: Option<String>,
    /// Tool names attached to the agent; `AutoAgents::generate` leaves the
    /// list empty.
    pub tools: Vec<String>,
}
impl AutoAgents {
    /// Creates an instance with the given configuration and no generated
    /// agents.
    pub fn new(config: AutoAgentsConfig) -> Self {
        let agents = Vec::new();
        Self { config, agents }
    }

    /// Replaces the stored specifications with a single general-purpose
    /// assistant whose goal embeds `task`, and returns the new list.
    ///
    /// The stored configuration is not consulted.
    pub fn generate(&mut self, task: &str) -> &[AutoAgentSpec] {
        let assistant = AutoAgentSpec {
            name: String::from("Assistant"),
            role: String::from("General Assistant"),
            goal: format!("Complete the task: {}", task),
            backstory: None,
            tools: Vec::new(),
        };
        self.agents = vec![assistant];
        self.agents.as_slice()
    }

    /// Specifications produced by the most recent call to
    /// [`AutoAgents::generate`]; empty if it has not been called.
    pub fn agents(&self) -> &[AutoAgentSpec] {
        self.agents.as_slice()
    }
}
/// Serializable configuration carried by an [`AutoRagAgent`].
///
/// None of these fields are read by the code in this file; the descriptions
/// below are inferred from the names and should be confirmed against callers.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct AutoRagConfig {
    /// Presumably switches the feature on. Falls back to `false` when the
    /// key is absent during deserialization.
    #[serde(default)]
    pub enabled: bool,
    /// Presumably the size of each document chunk.
    pub chunk_size: Option<usize>,
    /// Presumably the overlap between consecutive chunks.
    pub chunk_overlap: Option<usize>,
    /// Presumably the name of the embedding model.
    pub embedding_model: Option<String>,
    /// Presumably the name of the vector store backend.
    pub vector_store: Option<String>,
    /// Presumably the number of results to retrieve per query.
    pub top_k: Option<usize>,
}
/// Keeps a list of document sources and a flag recording whether they have
/// been indexed since the list last changed.
///
/// Indexing and querying are placeholders in this file: `index` only sets
/// the flag and `query` returns a canned string.
#[derive(Debug, Clone)]
pub struct AutoRagAgent {
    /// Configuration supplied at construction; not read by the methods below.
    pub config: AutoRagConfig,
    /// Registered sources in insertion order.
    sources: Vec<String>,
    /// `true` after `index` has run and no source has been added since.
    indexed: bool,
}

impl AutoRagAgent {
    /// Creates an agent with no sources, marked as not indexed.
    pub fn new(config: AutoRagConfig) -> Self {
        let sources = Vec::new();
        Self {
            config,
            sources,
            indexed: false,
        }
    }

    /// Registers one source and marks the agent as needing re-indexing.
    pub fn add_source(&mut self, source: impl Into<String>) {
        let source = source.into();
        self.sources.push(source);
        self.indexed = false;
    }

    /// Registers every source from `sources` and marks the agent as needing
    /// re-indexing, even when the iterator is empty.
    pub fn add_sources(&mut self, sources: impl IntoIterator<Item = impl Into<String>>) {
        self.sources
            .extend(sources.into_iter().map(Into::<String>::into));
        self.indexed = false;
    }

    /// Marks the sources as indexed. Always succeeds in this implementation.
    pub fn index(&mut self) -> Result<(), String> {
        self.indexed = true;
        Ok(())
    }

    /// Returns a placeholder result for `query`.
    ///
    /// # Errors
    ///
    /// Fails when the agent is not currently marked as indexed.
    pub fn query(&self, query: &str) -> Result<Vec<String>, String> {
        if self.indexed {
            Ok(vec![format!("Result for query: {}", query)])
        } else {
            Err(String::from("Documents not indexed. Call index() first."))
        }
    }

    /// Whether the agent is currently marked as indexed.
    pub fn is_indexed(&self) -> bool {
        self.indexed
    }

    /// The registered sources in insertion order.
    pub fn sources(&self) -> &[String] {
        self.sources.as_slice()
    }
}
/// Destination for [`TraceEvent`]s.
///
/// Every method takes `&self` and implementors must be `Send + Sync`, so a
/// sink can be shared between threads behind an `Arc` (see the `TraceSink`
/// alias in this file). Implementations that keep state therefore need
/// interior mutability.
pub trait TraceSinkProtocol: Send + Sync {
    /// Records one event.
    fn write(&self, event: &TraceEvent);

    /// Flush hook; the precise contract is not specified in this file.
    fn flush(&self);

    /// Close hook; the precise contract is not specified in this file.
    fn close(&self);
}
/// A single trace record: a type label, a timestamp, free-form JSON data and
/// optional trace/span identifiers.
#[derive(Debug, Clone)]
pub struct TraceEvent {
    /// Label identifying the kind of event.
    pub event_type: String,
    /// Wall-clock time captured when the event was constructed.
    pub timestamp: std::time::SystemTime,
    /// Arbitrary JSON payload keyed by name.
    pub data: HashMap<String, serde_json::Value>,
    /// Optional trace identifier.
    pub trace_id: Option<String>,
    /// Optional span identifier.
    pub span_id: Option<String>,
    /// Optional parent span identifier. No builder method sets it; assign the
    /// public field directly.
    pub parent_span_id: Option<String>,
}

impl TraceEvent {
    /// Creates an event of the given type, stamped with the current system
    /// time, with no data and no identifiers.
    pub fn new(event_type: impl Into<String>) -> Self {
        let event_type = event_type.into();
        let timestamp = std::time::SystemTime::now();
        Self {
            event_type,
            timestamp,
            data: HashMap::new(),
            trace_id: None,
            span_id: None,
            parent_span_id: None,
        }
    }

    /// Builder-style setter for the trace identifier.
    pub fn trace_id(self, id: impl Into<String>) -> Self {
        Self {
            trace_id: Some(id.into()),
            ..self
        }
    }

    /// Builder-style setter for the span identifier.
    pub fn span_id(self, id: impl Into<String>) -> Self {
        Self {
            span_id: Some(id.into()),
            ..self
        }
    }

    /// Builder-style insertion of one payload entry; an existing value under
    /// the same key is replaced.
    pub fn data(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key = key.into();
        self.data.insert(key, value);
        self
    }
}
/// In-memory trace sink that retains only the most recent events.
///
/// The derived `Default` has a capacity of 0, so a default sink discards
/// every event immediately after recording it.
#[derive(Debug, Default)]
pub struct ContextTraceSink {
    /// Retained events, oldest first, behind a lock so the sink can be
    /// written to through `&self`.
    events: std::sync::RwLock<Vec<TraceEvent>>,
    /// Maximum number of events retained.
    max_events: usize,
}

impl ContextTraceSink {
    /// Creates an empty sink that retains at most `max_events` events.
    pub fn new(max_events: usize) -> Self {
        let events = std::sync::RwLock::new(Vec::new());
        Self { events, max_events }
    }

    /// Returns a copy of the retained events, oldest first.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn events(&self) -> Vec<TraceEvent> {
        let guard = self.events.read().unwrap();
        guard.to_vec()
    }

    /// Discards every retained event.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn clear(&self) {
        let mut guard = self.events.write().unwrap();
        guard.clear();
    }
}

impl TraceSinkProtocol for ContextTraceSink {
    /// Stores a clone of `event`, then drops the oldest events so that at
    /// most `max_events` remain.
    ///
    /// Panics if the internal lock is poisoned.
    fn write(&self, event: &TraceEvent) {
        let mut buffer = self.events.write().unwrap();
        buffer.push(event.clone());
        let excess = buffer.len().saturating_sub(self.max_events);
        buffer.drain(..excess);
    }

    /// No-op: events are already held in memory.
    fn flush(&self) {}

    /// No-op: there is no underlying resource to release.
    fn close(&self) {}
}
/// Shared, dynamically dispatched handle to any [`TraceSinkProtocol`]
/// implementation.
pub type TraceSink = Arc<dyn TraceSinkProtocol>;
/// Identifies a memory storage backend.
///
/// The `Display` output of every variant parses back to the same variant via
/// `FromStr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryBackend {
    /// Displayed as `in_memory`. This is the default variant.
    #[default]
    InMemory,
    /// Displayed as `sqlite`.
    Sqlite,
    /// Displayed as `postgres`.
    Postgres,
    /// Displayed as `redis`.
    Redis,
    /// Displayed as `chroma`.
    Chroma,
    /// Displayed as `custom`.
    Custom,
}

/// Writes the canonical lowercase name of the backend. Width and alignment
/// flags in the format specification are ignored.
impl std::fmt::Display for MemoryBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::InMemory => "in_memory",
            Self::Sqlite => "sqlite",
            Self::Postgres => "postgres",
            Self::Redis => "redis",
            Self::Chroma => "chroma",
            Self::Custom => "custom",
        };
        f.write_str(name)
    }
}

/// Parses a backend name case-insensitively, accepting a few aliases
/// (`inmemory`/`memory`, `postgresql`, `chromadb`). Surrounding whitespace is
/// not trimmed.
impl std::str::FromStr for MemoryBackend {
    type Err = String;

    /// # Errors
    ///
    /// Returns `"Unknown memory backend: <input>"` (with the input exactly as
    /// given) when the name is not recognised.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let backend = match normalized.as_str() {
            "in_memory" | "inmemory" | "memory" => Self::InMemory,
            "sqlite" => Self::Sqlite,
            "postgres" | "postgresql" => Self::Postgres,
            "redis" => Self::Redis,
            "chroma" | "chromadb" => Self::Chroma,
            "custom" => Self::Custom,
            _ => return Err(format!("Unknown memory backend: {}", s)),
        };
        Ok(backend)
    }
}
/// Alias for [`crate::tools::ToolRegistry`].
pub type Tools = crate::tools::ToolRegistry;
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh context agent starts empty and reports every added entry.
    #[test]
    fn test_context_agent() {
        let mut agent = create_context_agent();
        assert!(agent.is_empty());

        agent.add("user", "Hello");
        agent.add("assistant", "Hi there!");

        assert!(!agent.is_empty());
        assert_eq!(agent.len(), 2);
        assert_eq!(agent.get_context().len(), 2);
    }

    /// Completing the only step of a plan takes progress from 0% to 100%.
    #[test]
    fn test_planning_agent() {
        let mut agent = PlanningAgent::default();
        agent.create_plan("Test task");
        assert_eq!(agent.progress(), 0.0);
        assert!(!agent.is_complete());

        agent.complete_step(Some(String::from("Done")));

        assert_eq!(agent.progress(), 100.0);
        assert!(agent.is_complete());
    }

    /// Entries that fit the budget are kept in order and joined by newlines.
    #[test]
    fn test_fast_context() {
        let mut ctx = FastContext::new(100);
        assert!(ctx.is_empty());

        ctx.add("Hello");
        ctx.add("World");

        assert!(!ctx.is_empty());
        assert_eq!(ctx.get().len(), 2);
        assert_eq!(ctx.as_string(), "Hello\nWorld");
    }

    /// Sources are recorded and the indexed flag flips once `index` runs.
    #[test]
    fn test_auto_rag_agent() {
        let mut agent = AutoRagAgent::new(AutoRagConfig {
            enabled: true,
            chunk_size: Some(500),
            ..Default::default()
        });

        agent.add_source("doc1.pdf");
        agent.add_source("doc2.txt");
        assert_eq!(agent.sources().len(), 2);
        assert!(!agent.is_indexed());

        agent.index().unwrap();
        assert!(agent.is_indexed());
    }

    /// The builder methods populate the matching fields.
    #[test]
    fn test_trace_event() {
        let event = TraceEvent::new("test_event")
            .trace_id("trace-123")
            .span_id("span-456")
            .data("key", serde_json::json!("value"));

        assert_eq!(event.event_type, "test_event");
        assert_eq!(event.trace_id, Some(String::from("trace-123")));
        assert_eq!(event.span_id, Some(String::from("span-456")));
    }

    /// Written events can be read back and then cleared.
    #[test]
    fn test_context_trace_sink() {
        let sink = ContextTraceSink::new(10);
        sink.write(&TraceEvent::new("test"));
        assert_eq!(sink.events().len(), 1);

        sink.clear();
        assert!(sink.events().is_empty());
    }

    /// Canonical backend names parse to the matching variants.
    #[test]
    fn test_memory_backend_parse() {
        let cases = [
            ("sqlite", MemoryBackend::Sqlite),
            ("postgres", MemoryBackend::Postgres),
            ("redis", MemoryBackend::Redis),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(name.parse::<MemoryBackend>().unwrap(), *expected);
        }
    }
}