use std::collections::HashMap;
use std::io::Write;
use std::sync::Arc;
use std::time::Instant;
use console::style;
use futures::Stream;
use futures_util::StreamExt;
use tokio::sync::Mutex;
use crate::memory::injector::build_injected_prompt;
use crate::memory::store::MemoryStore;
use crate::messages::{Message, ToolCall};
use crate::provider::Provider;
use crate::tool::ToolSpec;
use crate::ui;
/// Progress events produced by `Agent::run_stream` while a turn is running.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// A fragment of assistant text, forwarded as soon as the provider yields it.
    Token(String),
    /// A tool call is about to run; `args` holds the JSON-serialised arguments.
    ToolStart {
        name: String,
        args: String,
    },
    /// A tool call has finished; `run_stream` sets `success` to false when the
    /// tool's result text begins with "Error".
    ToolEnd {
        name: String,
        success: bool,
    },
    /// A new model round-trip is starting (the unit payload carries no data).
    Iteration(()),
    /// Token usage reported by the provider after a completion.
    Usage(crate::messages::Usage),
    /// The turn finished.
    Done,
    /// A user-facing error message.
    Error(String),
}
/// Interactive LLM agent.
///
/// Holds the active provider client, the tool registry, optional long-term
/// memory, generation settings and the running conversation, plus some
/// private per-session bookkeeping used by the REPL.
pub struct Agent {
    // ---- provider selection -------------------------------------------
    /// Client used for completions; replaced when the user switches provider.
    pub provider: Arc<dyn Provider>,
    /// Raw per-provider config entries keyed by provider name; fields such as
    /// `api_key`, `base_url` and `model` are read from these values.
    pub providers_map: HashMap<String, serde_json::Value>,
    /// Provider names in the order they are offered by the `/provider` picker.
    pub provider_names: Vec<String>,
    /// Name of the provider currently in use.
    pub active_provider: String,
    /// Model currently in use.
    pub active_model: String,
    /// Token usage accumulated over the session.
    pub session_usage: std::sync::Mutex<crate::messages::Usage>,

    // ---- tools ----------------------------------------------------------
    /// Registered tools keyed by tool name.
    pub tools: Arc<HashMap<String, Arc<ToolSpec>>>,
    /// Tool definitions in OpenAI function-calling form; `None` when no tools exist.
    pub openai_tool_defs: Option<Vec<serde_json::Value>>,

    // ---- memory and prompting ------------------------------------------
    /// Optional persistent memory store.
    pub memory_store: Option<Arc<std::sync::Mutex<MemoryStore>>>,
    /// Base system prompt before per-turn enrichment.
    pub system_prompt: String,

    // ---- generation settings -------------------------------------------
    /// Upper bound on model round-trips per turn.
    pub max_iterations: u32,
    pub max_tokens: Option<u32>,
    pub temperature: f32,
    /// Print diagnostic output to stderr.
    pub verbose: bool,

    /// Conversation history, shared with the background streaming task.
    pub messages: Arc<Mutex<Vec<Message>>>,

    // ---- private session state -----------------------------------------
    /// Tool error messages keyed by tool name; cleared by `reset`.
    tool_errors: std::sync::Mutex<HashMap<String, Vec<String>>>,
    /// Set when the interactive session starts.
    session_start: Option<Instant>,
    /// Number of turns sent to the model in the current interactive session.
    turn_count: std::sync::Mutex<u64>,
    /// Short-lived topic hints gathered from recent user inputs.
    ephemeral: std::sync::Mutex<Vec<crate::memory::store::EphemeralEntry>>,
    /// Recently successful tool invocations.
    tool_usage_history: std::sync::Mutex<Vec<(String, String)>>,
}
impl Agent {
/// Builds an agent from a provider, a tool list and its configuration.
///
/// Tools are indexed by name (a later tool with a duplicate name replaces the
/// earlier one) and pre-rendered as OpenAI tool definitions, which are `None`
/// when the list is empty. History, counters and session state start empty,
/// and no session start time is recorded yet.
pub fn new(
    provider: Box<dyn Provider>,
    tools: Vec<ToolSpec>,
    memory_store: Option<MemoryStore>,
    system_prompt: &str,
    max_iterations: u32,
    max_tokens: Option<u32>,
    temperature: f32,
    verbose: bool,
    providers_map: HashMap<String, serde_json::Value>,
    provider_names: Vec<String>,
    active_provider: String,
    active_model: String,
) -> Self {
    let shared: Vec<Arc<ToolSpec>> = tools.into_iter().map(Arc::new).collect();

    let mut registry: HashMap<String, Arc<ToolSpec>> = HashMap::new();
    for spec in &shared {
        registry.insert(spec.name.clone(), Arc::clone(spec));
    }

    let openai_tool_defs: Option<Vec<serde_json::Value>> = if shared.is_empty() {
        None
    } else {
        Some(shared.iter().map(|spec| spec.to_openai_tool()).collect())
    };

    Self {
        provider: Arc::from(provider),
        providers_map,
        provider_names,
        active_provider,
        active_model,
        session_usage: std::sync::Mutex::new(crate::messages::Usage::default()),
        tools: Arc::new(registry),
        openai_tool_defs,
        memory_store: memory_store.map(std::sync::Mutex::new).map(Arc::new),
        system_prompt: system_prompt.to_owned(),
        max_iterations,
        max_tokens,
        temperature,
        verbose,
        messages: Arc::new(Mutex::new(Vec::new())),
        tool_errors: std::sync::Mutex::new(HashMap::new()),
        session_start: None,
        turn_count: std::sync::Mutex::new(0),
        ephemeral: std::sync::Mutex::new(Vec::new()),
        tool_usage_history: std::sync::Mutex::new(Vec::new()),
    }
}
fn build_enriched_prompt(&self, user_input: &str) -> String {
let provider_url = self.providers_map.get(&self.active_provider)
.and_then(|e| e.get("base_url"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let identity = format!(
"\n\n## Current Configuration\n- Provider: {}\n- Model: {}\n- API: {}\n",
self.active_provider, self.active_model, provider_url
);
let base = format!("{}{}", self.system_prompt, identity);
let ephemeral = self.ephemeral.lock().ok().map(|e| e.clone()).unwrap_or_default();
if let Some(ref store_mutex) = self.memory_store {
if let Ok(store) = store_mutex.lock() {
return build_injected_prompt(&base, &store, user_input, &[], 15, 5, &ephemeral);
}
}
base
}
/// Prepares the history for a new user turn.
///
/// The system prompt is rebuilt for `user_input`. If `reset` is set or the
/// history is empty, the conversation restarts with just that system message.
/// Otherwise the first message is replaced in place. The user message is then
/// appended.
async fn init_turn(&self, user_input: &str, reset: bool) {
    let prompt = self.build_enriched_prompt(user_input);
    let mut history = self.messages.lock().await;
    let keep_history = !reset && !history.is_empty();
    if keep_history {
        history[0] = Message::new_system(prompt);
    } else {
        *history = vec![Message::new_system(prompt)];
    }
    history.push(Message::new_user(user_input));
}
/// Runs one non-streaming agent turn and returns the final assistant text.
///
/// The model is called at most `max_iterations` times. Whenever a reply
/// requests tools, each requested tool is executed and its result appended to
/// the history before the next call.
///
/// Returns:
/// - the assistant's text when it replies without tool calls;
/// - `"Provider error: …"` if a completion request fails;
/// - `"(Reached max iterations=N)"` when the iteration budget runs out.
pub async fn run(&self, user_input: &str, reset: bool) -> String {
    self.init_turn(user_input, reset).await;
    let tool_errors = std::sync::Mutex::new(HashMap::new());
    for _ in 0..self.max_iterations {
        let snapshot = self.messages.lock().await.clone();
        let response = match self
            .provider
            .chat_completion(
                &snapshot,
                self.openai_tool_defs.as_deref(),
                "auto",
                self.max_tokens,
                self.temperature,
            )
            .await
        {
            Ok(r) => r,
            Err(e) => return format!("Provider error: {}", e),
        };
        if self.verbose {
            let preview = response
                .content
                .as_ref()
                .map(|c| c.chars().take(200).collect::<String>())
                .unwrap_or_default();
            eprintln!(" [llm → {}] {}", response.role, preview);
        }
        // Read what we need before the reply is moved into the history, instead
        // of re-locking the history to look at the message we just pushed.
        let requested = response.tool_calls.clone();
        let reply_text = response.content.clone();
        self.messages.lock().await.push(response);
        let calls = match requested {
            Some(calls) => calls,
            None => return reply_text.unwrap_or_default(),
        };
        for tc in &calls {
            let result = execute_tool(tc, &self.tools, &self.memory_store, self.verbose, &tool_errors);
            self.messages.lock().await.push(Message::new_tool(result, &tc.id, &tc.name));
        }
    }
    // We only get here when the model kept requesting tools until the budget
    // ran out (or max_iterations is 0). In that case the last message is a tool
    // result or the user prompt, not an answer, so report the exhaustion.
    format!("(Reached max iterations={})", self.max_iterations)
}
pub fn run_stream(&self, user_input: String, _reset: bool) -> impl Stream<Item = StreamEvent> + Send + 'static {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<StreamEvent>();
let messages = self.messages.clone();
let provider = self.provider.clone();
let tools = self.tools.clone();
let tool_defs = self.openai_tool_defs.clone();
let memory_store = self.memory_store.clone();
let verbose = self.verbose;
let max_iterations = self.max_iterations;
let max_tokens = self.max_tokens;
let temperature = self.temperature;
let prompt = self.build_enriched_prompt(&user_input);
tokio::spawn(async move {
{
let mut msgs = messages.lock().await;
*msgs = vec![Message::new_system(prompt), Message::new_user(&user_input)];
}
let tool_errors = std::sync::Mutex::new(HashMap::new());
for _iteration in 0..max_iterations {
let _ = tx.send(StreamEvent::Iteration(()));
let msgs_snapshot = messages.lock().await.clone();
let mut stream = match provider.chat_completion_stream(&msgs_snapshot, tool_defs.as_deref(), "auto", max_tokens, temperature).await {
Ok(s) => s,
Err(e) => {
let user_msg = if e.to_string().contains("error sending request") || e.to_string().contains("dns") || e.to_string().contains("resolve") {
format!("No connection to current provider.\n Use /switch <name> to pick a working provider.\n Or check config.yaml → active_provider and api_key.")
} else {
format!("{}\n Use /switch <name> to try another provider.", e.to_string())
};
let _ = tx.send(StreamEvent::Error(user_msg)); break; }
};
let mut content = String::new();
while let Some(token_result) = stream.next().await {
match token_result {
Ok(token) => { content.push_str(&token); let _ = tx.send(StreamEvent::Token(token)); }
Err(e) => { let _ = tx.send(StreamEvent::Error(e.to_string())); }
}
}
let msg = provider.last_stream_message().unwrap_or_else(|| Message::new_assistant(Some(content.clone()), None));
messages.lock().await.push(msg);
if let Some(u) = provider.last_usage() {
let _ = tx.send(StreamEvent::Usage(u));
}
if messages.lock().await.last().and_then(|m| m.tool_calls.as_ref()).is_none() {
let _ = tx.send(StreamEvent::Done); return;
}
let calls = messages.lock().await.last().and_then(|m| m.tool_calls.clone()).unwrap_or_default();
for tc in &calls {
let args_json = serde_json::to_string(&tc.arguments).unwrap_or_default();
let _ = tx.send(StreamEvent::ToolStart { name: tc.name.clone(), args: args_json });
let result = execute_tool(tc, &tools, &memory_store, verbose, &tool_errors);
let success = !result.starts_with("Error");
let _ = tx.send(StreamEvent::ToolEnd { name: tc.name.clone(), success });
messages.lock().await.push(Message::new_tool(result, &tc.id, &tc.name));
}
}
let _ = tx.send(StreamEvent::Done);
});
tokio_stream::wrappers::UnboundedReceiverStream::new(rx)
}
/// Clears the conversation history and the recorded tool errors.
///
/// The memory store is not touched. The history lock stays held until the
/// error map has been cleared as well.
pub async fn reset(&self) {
    let mut history = self.messages.lock().await;
    history.clear();
    let _ = self.tool_errors.lock().map(|mut errors| errors.clear());
}
/// Switches to the provider configured under `name`.
///
/// Builds an OpenAI-compatible client from the entry's settings:
/// - `api_key` (default empty);
/// - `base_url` (default `https://api.openai.com/v1`);
/// - `model` (default `gpt-4o`).
///
/// On success it resets the conversation and updates `active_provider` and
/// `active_model`, so the enriched prompt and status bar describe the new
/// provider, then returns `true`. Returns `false` for an unknown name, or when
/// the client cannot be created (the error is printed in verbose mode). Nothing
/// is written back to the config file.
// No caller in this part of the file; the REPL uses `/switch` and
// `activate_provider` instead, hence the dead_code allowance.
#[allow(dead_code)]
pub async fn switch_provider(&mut self, name: &str) -> bool {
    let entry = match self.providers_map.get(name) {
        Some(entry) => entry,
        None => return false,
    };
    let api_key = entry.get("api_key").and_then(|v| v.as_str()).unwrap_or("");
    let base_url = entry.get("base_url").and_then(|v| v.as_str()).unwrap_or("https://api.openai.com/v1");
    let model = entry.get("model").and_then(|v| v.as_str()).unwrap_or("gpt-4o").to_string();
    match crate::providers::openai_compat::create_provider("openai", &model, api_key, Some(base_url)) {
        Ok(new_provider) => {
            self.provider = Arc::from(new_provider);
            self.reset().await;
            // Keep the displayed/injected configuration in sync with the live
            // client, as the `/switch` command does.
            self.active_provider = name.to_string();
            self.active_model = model;
            true
        }
        Err(e) => {
            if self.verbose {
                eprintln!("Failed to switch: {}", e);
            }
            false
        }
    }
}
async fn activate_provider(&mut self, name: &str, model: &str, api_key: &str, base_url: &str) {
if let Ok(new_provider) = crate::providers::openai_compat::create_provider("openai", model, api_key, Some(base_url)) {
self.provider = Arc::from(new_provider);
self.reset().await;
if let Some(ref mut entry) = self.providers_map.get_mut(name) {
if let Some(obj) = entry.as_object_mut() {
obj.insert("model".into(), serde_json::Value::String(model.to_string()));
}
}
save_model_to_config(name, model).ok();
self.active_provider = name.to_string();
self.active_model = model.to_string();
save_active_config(&self.active_provider, &self.active_model).ok();
ui::render_info(&format!("Switched to {} ({})", name, model));
} else {
ui::render_error(&format!("Failed to switch to '{}'.", name));
}
}
pub async fn chat(&mut self) {
self.reset().await;
self.session_start = Some(Instant::now());
if let Ok(mut tc) = self.turn_count.lock() { *tc = 0; }
let interrupted = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
let intr = interrupted.clone();
tokio::spawn(async move {
loop {
tokio::signal::ctrl_c().await.ok();
intr.store(true, std::sync::atomic::Ordering::Relaxed);
eprintln!("\n {} Interrupt — press Ctrl+C again to force quit, or type /exit", style("⚠").yellow());
}
});
let recent = self.memory_store.as_ref().and_then(|m| {
m.lock().ok().and_then(|store| store.recent_memories(3).ok())
}).unwrap_or_default();
let session = crate::session::load_session_state();
ui::welcome(
self.tools.len(), self.memory_store.is_some(),
&self.active_provider, &self.active_model, &recent,
session.last_session_turns, session.last_session_tokens, session.last_session_duration_secs,
);
let mut rl = ui::create_editor();
loop {
let turn_number = self.turn_count.lock().map(|t| *t).unwrap_or(0);
let width = ui::draw_full_box(&self.active_provider, &self.active_model, turn_number);
let input = match ui::readline(&mut rl, width) {
Some(s) => s,
None => { println!(); break; }
};
if input.is_empty() { continue; }
ui::box_enter(width);
let display_input = if input.contains('\n') {
let line_count = input.lines().count();
let preview: String = input.lines().next().unwrap_or("").chars().take(50).collect();
format!("{} … ({} lines)", preview, line_count)
} else {
input.clone()
};
ui::print_user_turn(&display_input, width);
if let Ok(mut eph) = self.ephemeral.lock() {
let turn = self.turn_count.lock().map(|t| *t).unwrap_or(0);
let words: Vec<&str> = input.split_whitespace().collect();
for w in words.iter().take(3) {
if w.len() > 4 {
let topic = w.to_string();
let detail = input.chars().take(80).collect::<String>();
if !eph.iter().any(|e| e.topic == topic) {
eph.push(crate::memory::store::EphemeralEntry { topic, detail, turn });
}
}
}
while eph.len() > 8 { eph.remove(0); }
}
if let Some(ref store_mutex) = self.memory_store {
let prefs = crate::memory::store::extract_preferences(&input);
if !prefs.is_empty() {
if let Ok(store) = store_mutex.lock() {
for (cat, content, imp) in prefs {
let _ = store.save_memory("user", &content, &cat, imp, &[]);
}
}
}
}
let cmd = input.to_lowercase();
match cmd.as_str() {
"exit" | "quit" | "/exit" => break,
"/reset" | "/clear" => {
self.reset().await;
ui::render_info("Conversation cleared — memory intact");
println!();
continue;
}
cmd if cmd == "/memory" || cmd.starts_with("/memory ") => {
if let Some(ref store_mutex) = self.memory_store {
if let Ok(store) = store_mutex.lock() {
let sub = cmd.trim_start_matches("/memory").trim();
match sub {
"tree" | "--tree" | "-t" => {
if let Ok(entries) = store.list_memories(None, 100) {
ui::render_memory_tree(&entries);
}
}
"timeline" | "--timeline" | "-l" => {
if let Ok(entries) = store.list_memories(None, 50) {
ui::render_memory_timeline(&entries);
}
}
"stats" | "--stats" | "-s" => {
if let Ok(stats) = store.memory_stats() {
let total = stats.get("total").copied().unwrap_or(0);
println!(" {} Memory Stats ({} total)", style("◈").cyan(), total);
println!(" {}", style(crate::ui::DIVIDER.repeat(40)).dim());
for (key, count) in stats.iter().filter(|(k,_)| *k != "total") {
println!(" {:>20}: {}", key, count);
}
println!();
}
}
_ => {
if let Ok(entries) = store.list_memories(None, 20) {
ui::render_memory_table(&entries);
}
}
}
}
} else { ui::render_info("Memory is not enabled"); }
println!(); continue;
}
cmd if cmd == "/skills" || cmd.starts_with("/skills ") => {
if let Some(ref store_mutex) = self.memory_store {
if let Ok(store) = store_mutex.lock() {
let sub = cmd.trim_start_matches("/skills").trim();
match sub {
"add" | "create" | "--add" => {
use dialoguer::Input;
let name: String = Input::new().with_prompt("Skill name").interact_text().unwrap_or_default();
if !name.is_empty() {
let desc: String = Input::new().with_prompt("Description").interact_text().unwrap_or_default();
println!(" {} Enter skill content (markdown). Type /done on a new line when finished:", style("ℹ").dim());
let mut content = String::new();
loop {
let line: String = Input::new().with_prompt("").interact_text().unwrap_or_default();
if line.trim() == "/done" { break; }
content.push_str(&line);
content.push('\n');
}
if store.get_skill(&name).ok().flatten().is_none() {
let _ = store.save_skill(&name, &desc, &content, "user");
println!(" {} Skill '{}' created.", style("✓").green(), name);
} else {
println!(" {} Skill '{}' already exists.", style("✗").red(), name);
}
}
}
"delete" | "rm" | "--delete" => {
use dialoguer::Input;
let name: String = Input::new().with_prompt("Skill name to delete").interact_text().unwrap_or_default();
if !name.is_empty() && store.delete_skill(&name).unwrap_or(false) {
println!(" {} Skill '{}' deleted.", style("✓").green(), name);
} else if !name.is_empty() {
println!(" {} Skill '{}' not found.", style("✗").red(), name);
}
}
_ => {
if let Ok(entries) = store.list_skills(None) { ui::render_skill_table(&entries); }
}
}
}
} else { ui::render_info("Memory is not enabled"); }
println!(); continue;
}
"/tip" => { ui::show_tip(); println!(); continue; }
cmd if cmd == "/tools" || cmd.starts_with("/tools ") => {
use std::collections::BTreeMap;
let mut by_toolset: BTreeMap<&str, Vec<String>> = BTreeMap::new();
for tool in self.tools.values() {
let cat = if tool.name.contains("memory") || tool.name.contains("skill") { "Memory"
} else if tool.name.contains("git") { "Git"
} else if tool.name.contains("web") { "Web"
} else if tool.name.contains("execute") || tool.name.contains("bash") { "Code"
} else if tool.name.contains("file") || tool.name.contains("read") || tool.name.contains("write") { "Files"
} else { "Utility" };
by_toolset.entry(cat).or_default().push(format!("{} — {}", style(&tool.name).bold(), tool.description));
}
println!(" {} {}", style("Available Tools").bold().white(), format!("({} total)", self.tools.len()));
println!(" {}", style(crate::ui::DIVIDER.repeat(40)).dim());
for (cat, items) in &by_toolset {
println!(" {} {}", style("▸").cyan(), style(*cat).bold());
for item in items {
println!(" {}", item);
}
println!();
}
println!(); continue;
}
cmd if cmd == "/theme" || cmd.starts_with("/theme ") => {
let sub = cmd.trim_start_matches("/theme").trim();
match sub {
"latte" | "light" => { ui::set_theme(ui::Theme::Latte); ui::render_info("Theme set to Latte"); }
"mocha" | "dark" => { ui::set_theme(ui::Theme::Mocha); ui::render_info("Theme set to Mocha"); }
_ => {
let current = ui::get_theme();
println!(" Current theme: {:?}", current);
println!(" Usage: /theme <mocha|latte>");
}
}
println!(); continue;
}
cmd if cmd == "/plugin" || cmd.starts_with("/plugin ") => {
let sub = cmd.trim_start_matches("/plugin").trim();
match sub {
"list" | "ls" | "" => {
let states = crate::plugin::list_plugin_states();
if states.is_empty() {
println!(" No plugins found. Add .yaml files to ~/.cortex/plugins/");
} else {
println!(" {} Plugins", style("🔌").dim());
println!(" {}", style(crate::ui::DIVIDER.repeat(40)).dim());
for (name, enabled) in &states {
let indicator = if *enabled { "●" } else { "○" };
println!(" {} {} {}", style(indicator).green(), name, if *enabled { style("enabled").green() } else { style("disabled").dim() });
}
}
}
"enable" | "on" => {
println!(" Usage: edit ~/.cortex/plugins/<name>.yaml and set enabled: true");
}
rest if rest.starts_with("install") || rest.starts_with("add") => {
let url_part = rest.trim_start_matches("install").trim().trim_start_matches("add").trim();
if url_part.is_empty() {
println!(" Usage: /plugin install <url>");
println!(" Examples:");
println!(" /plugin install https://raw.githubusercontent.com/user/repo/main/plugin.yaml");
println!(" /plugin install https://github.com/user/repo");
} else {
match crate::plugin::install_plugin_from_url(url_part) {
Ok(msg) => println!(" {} {}", style("✓").green(), msg),
Err(e) => println!(" {} Failed: {}", style("✗").red(), e),
}
}
}
_ => {
println!(" Usage: /plugin list");
}
}
println!(); continue;
}
"/model" => {
let entry = self.providers_map.get(&self.active_provider).cloned().unwrap_or_default();
let base_url = entry.get("base_url").and_then(|v| v.as_str()).unwrap_or("https://api.openai.com/v1").to_string();
let current_key = entry.get("api_key").and_then(|v| v.as_str()).unwrap_or("").to_string();
let api_key_set = !current_key.is_empty() && !current_key.contains("${");
ui::render_model_info(&self.active_provider, &self.active_model, &base_url, api_key_set);
println!(); continue;
}
"/stats" | "/cost" => {
if let Ok(usage) = self.session_usage.lock() {
let cost = usage.session_cost(&self.active_model);
let elapsed = self.session_start.map(|s| s.elapsed());
let tc = self.turn_count.lock().map(|t| *t).unwrap_or(0);
ui::render_stats(
usage.session_prompt, usage.session_completion, usage.session_total,
cost,
&self.active_provider, &self.active_model,
self.tools.len(), self.memory_store.is_some(),
tc, elapsed,
);
}
println!(); continue;
}
cmd if cmd == "/help" || cmd.starts_with("/help ") || cmd == "help" => {
let topic = cmd.trim_start_matches("/help").trim().trim_start_matches("help").trim();
if topic.is_empty() {
ui::show_help();
} else {
let full_cmd = if topic.starts_with('/') { topic.to_string() } else { format!("/{}", topic) };
ui::show_command_help(&full_cmd);
}
continue;
}
"/provider" => {
use dialoguer::{Select, Confirm, Input as DInput};
let provider_items: Vec<String> = self.provider_names.iter().map(|p| {
if p == &self.active_provider {
format!("{} {} {}", style("●").cyan(), style(p).bold().white(), style("(active)").dim().cyan())
} else {
format!(" {}", p)
}
}).collect();
let selection = Select::new()
.with_prompt("Select provider (↑↓ arrows, Enter to confirm, Esc to cancel)")
.items(&provider_items)
.default(self.provider_names.iter().position(|p| p == &self.active_provider).unwrap_or(0))
.interact_opt();
let idx = match selection {
Ok(Some(i)) => i,
_ => { println!(); continue; }
};
let name = self.provider_names[idx].clone();
let entry = self.providers_map.get(&name).cloned().unwrap_or_default();
let current_key = entry.get("api_key").and_then(|v| v.as_str()).unwrap_or("").to_string();
let base_url = entry.get("base_url").and_then(|v| v.as_str()).unwrap_or("https://api.openai.com/v1").to_string();
let current_model = entry.get("model").and_then(|v| v.as_str()).unwrap_or("").to_string();
let has_key = !current_key.is_empty() && !current_key.contains("${");
if has_key {
let masked = if current_key.len() > 8 { format!("{}...", ¤t_key[..8]) } else { "***".into() };
println!("\n {} Key: {} (saved)", style("🔑").dim(), style(masked).dim());
if !Confirm::new().with_prompt("Replace key?").default(false).interact().unwrap_or(false) {
} else {
let key: String = DInput::new().with_prompt("New API key").interact_text().unwrap_or_default();
if key.is_empty() { println!(); continue; }
if let Some(ref mut entry) = self.providers_map.get_mut(&name) {
if let Some(obj) = entry.as_object_mut() {
obj.insert("api_key".into(), serde_json::Value::String(key.clone()));
}
}
save_api_key_to_config(&name, &key).ok();
}
} else {
let key: String = DInput::new().with_prompt("API key").interact_text().unwrap_or_default();
if key.is_empty() { println!(); continue; }
if let Some(ref mut entry) = self.providers_map.get_mut(&name) {
if let Some(obj) = entry.as_object_mut() {
obj.insert("api_key".into(), serde_json::Value::String(key.clone()));
}
}
save_api_key_to_config(&name, &key).ok();
}
let api_key = self.providers_map.get(&name)
.and_then(|e| e.get("api_key")).and_then(|v| v.as_str()).unwrap_or("").to_string();
let mut models = fetch_provider_models(&base_url, &api_key).await;
if models.is_empty() {
models = ui::provider_models(&name).into_iter().map(String::from).collect();
}
if models.is_empty() {
let model_input: String = DInput::new().with_prompt("Model name").interact_text().unwrap_or_default();
if model_input.is_empty() { println!(); continue; }
self.activate_provider(&name, &model_input, &api_key, &base_url).await;
} else {
let default_model_idx = if !current_model.is_empty() {
models.iter().position(|m| m == ¤t_model).unwrap_or(0)
} else { 0 };
let model_items: Vec<String> = models.iter().map(|m| {
if m == ¤t_model { format!("{} {} {}", style("●").cyan(), m, style("(current)").dim()) }
else { format!(" {}", m) }
}).collect();
let model_sel = Select::new()
.with_prompt("Select model (↑↓ arrows, Enter to confirm)")
.items(&model_items)
.default(default_model_idx)
.interact_opt();
match model_sel {
Ok(Some(mi)) => {
let model = &models[mi];
self.activate_provider(&name, model, &api_key, &base_url).await;
}
_ => { println!(); continue; }
}
}
println!(); continue;
}
_ => {}
}
if let Some(name) = cmd.strip_prefix("/switch ") {
let name = name.trim().to_string();
if !self.providers_map.contains_key(&name) {
ui::render_error(&format!("Unknown provider '{}'. Use /provider to list.", name));
println!(); continue;
}
let entry = self.providers_map.get(&name).cloned().unwrap_or_default();
let current_key = entry.get("api_key").and_then(|v| v.as_str()).unwrap_or("").to_string();
let needs_key = current_key.is_empty() || current_key.contains("${");
if needs_key {
ui::render_info(&format!("No API key configured for '{}'.", name));
let key = ui::prompt_input(&format!("Enter API key for {}: ", name));
if key.is_empty() {
ui::render_error("No key entered. Provider not switched.");
println!(); continue;
}
if let Some(ref mut entry) = self.providers_map.get_mut(&name) {
if let Some(obj) = entry.as_object_mut() {
obj.insert("api_key".into(), serde_json::Value::String(key.clone()));
}
}
save_api_key_to_config(&name, &key).ok();
ui::render_info("API key saved to config.yaml");
}
let model = ui::prompt_input(&format!("Model for {} (e.g. meta/llama-3.1-8b-instruct): ", name));
let model = if model.is_empty() {
"gpt-4o"
} else {
if let Some(ref mut entry) = self.providers_map.get_mut(&name) {
if let Some(obj) = entry.as_object_mut() {
obj.insert("model".into(), serde_json::Value::String(model.clone()));
}
}
&model
};
let api_key = self.providers_map.get(&name)
.and_then(|e| e.get("api_key"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let base_url = self.providers_map.get(&name)
.and_then(|e| e.get("base_url"))
.and_then(|v| v.as_str())
.unwrap_or("https://api.openai.com/v1")
.to_string();
if let Ok(new_provider) = crate::providers::openai_compat::create_provider("openai", model, &api_key, Some(&base_url)) {
self.provider = Arc::from(new_provider);
self.reset().await;
ui::render_info(&format!("Switched to {} ({})", name, model));
if !model.is_empty() && model != "gpt-4o" {
save_model_to_config(&name, model).ok();
}
self.active_provider = name.clone();
self.active_model = model.to_string();
save_active_config(&self.active_provider, &self.active_model).ok();
} else {
ui::render_error(&format!("Failed to switch to provider '{}'.", name));
}
println!(); continue;
}
if cmd.starts_with('/') {
ui::render_info(&format!("Unknown command '{}'. Type /help.", cmd));
println!(); continue;
}
let lower = input.to_lowercase();
let typo_fixes = [
("swtich", "/switch"), ("swich", "/switch"), ("switc", "/switch"),
("changeprovider", "/switch"), ("provider", "/provider"),
("hel", "/help"), ("hlep", "/help"), ("command", "/help"),
("mem", "/memory"), ("memories", "/memory"),
("skil", "/skills"), ("skill", "/skills"),
("res", "/reset"), ("rest", "/reset"), ("clear", "/reset"),
("ext", "/exit"), ("qui", "/quit"),
("tip", "/tip"), ("hint", "/tip"),
("modle", "/model"), ("mdel", "/model"),
("stat", "/stats"), ("cost", "/stats"),
];
let mut suggestion = None;
for (typo, fix) in &typo_fixes {
if lower == *typo || lower.starts_with(typo) {
suggestion = Some(fix);
break;
}
}
if let Some(fix) = suggestion {
ui::render_info(&format!("Did you mean '{}'?", fix));
println!(); continue;
}
let _turn_start = Instant::now();
if let Ok(mut tc) = self.turn_count.lock() { *tc += 1; }
let turn_number = self.turn_count.lock().map(|t| *t).unwrap_or(1);
let spinner = ui::Spinner::start(&self.active_provider, &self.active_model);
let mut stream = self.run_stream(input.clone(), true);
use futures::StreamExt;
let mut response_text = String::new();
let mut has_tokens = false;
let mut line_start = false;
let mut turn_prompt = 0u64;
let mut turn_completion = 0u64;
loop {
match stream.next().await {
Some(StreamEvent::Token(token)) => {
if !has_tokens {
spinner.stop();
has_tokens = true;
line_start = true;
}
for ch in token.chars() {
if ch == '\n' {
if line_start {
let w = ui::terminal_width();
println!("{}", style(format!(" │{:width$}│", "", width = w.saturating_sub(4))).dim());
} else {
let w = ui::terminal_width();
let content = std::mem::take(&mut response_text);
let last_line = content.rsplit('\n').next().unwrap_or("");
let padded = format!("{:<width$}", last_line, width = w.saturating_sub(4));
print!("\r"); println!("{}", style(format!(" │ {}│", padded)).dim());
response_text = content; }
response_text.push('\n');
line_start = true;
} else {
if line_start {
print!("{}", style("│ ").dim());
line_start = false;
}
print!("{}", ch);
response_text.push(ch);
}
}
std::io::stdout().flush().ok();
}
Some(StreamEvent::Done) => {
if !has_tokens { spinner.stop(); }
if !line_start && !response_text.is_empty() {
let w = ui::terminal_width();
let last_line = response_text.rsplit('\n').next().unwrap_or("");
let padded = format!("{:<width$}", last_line, width = w.saturating_sub(4));
print!("\r");
println!("{}", style(format!(" │ {}│", padded)).dim());
}
ui::print_turn_separator(width);
break;
}
Some(StreamEvent::Error(e)) => {
spinner.stop();
println!();
ui::render_error(&e);
break;
}
Some(StreamEvent::ToolStart { name, args }) => {
if !has_tokens { spinner.stop(); has_tokens = true; }
println!();
line_start = true; ui::render_tool_line(&name, &args, "running");
std::io::stdout().flush().ok();
}
Some(StreamEvent::ToolEnd { name, success }) => {
ui::render_tool_line(&name, "", if success { "done" } else { "error" });
std::io::stdout().flush().ok();
if success {
if let Ok(mut history) = self.tool_usage_history.lock() {
let key = format!("{}", name);
history.push((key, name.clone()));
while history.len() > 30 { history.remove(0); }
}
}
}
Some(StreamEvent::Usage(u)) => {
if let Ok(mut session) = self.session_usage.lock() {
session.record_turn(u.prompt_tokens, u.completion_tokens);
turn_prompt = u.prompt_tokens;
turn_completion = u.completion_tokens;
}
}
Some(StreamEvent::Iteration(_)) => {}
None => break,
}
}
if turn_prompt == 0 && turn_completion == 0 {
let input_tokens = (input.len() / 4).max(1) as u64;
let output_tokens = (response_text.len() / 4).max(1) as u64;
if let Ok(mut session) = self.session_usage.lock() {
session.record_turn(input_tokens, output_tokens);
}
}
if let Ok(usage) = self.session_usage.lock() {
let session_cost = usage.session_cost(&self.active_model);
let elapsed_secs = self.session_start.map(|s| s.elapsed().as_secs()).unwrap_or(0);
ui::session_status_bar(
&self.active_provider,
&self.active_model,
usage.session_total,
session_cost,
elapsed_secs,
turn_number,
self.tools.len(),
);
}
}
ui::save_history(&mut rl);
if let Ok(usage) = self.session_usage.lock() {
let turn_count = self.turn_count.lock().map(|t| *t).unwrap_or(0);
if turn_count > 0 {
let elapsed = self.session_start.map(|s| s.elapsed()).unwrap_or_default();
let s = elapsed.as_secs();
let topics = extract_session_topics(&self.messages);
let topic_str = if topics.is_empty() {
String::new()
} else {
format!(". Topics: {}", topics.join(", "))
};
let session_summary = format!(
"Session: {} turns, {} tokens, {}. Provider: {}/{}{}",
turn_count, usage.session_total,
if s > 60 { format!("{}m {}s", s/60, s%60) } else { format!("{}s", s) },
self.active_provider, self.active_model, topic_str
);
let tags: Vec<String> = topics.iter().map(|t| t.to_lowercase()).collect();
if let Some(ref store_mutex) = self.memory_store {
if let Ok(store) = store_mutex.lock() {
let _ = store.save_memory(
"memory", &session_summary, "session",
if turn_count > 5 { 2 } else { 1 },
&tags,
);
let _ = store.decay_old_memories();
}
}
let cost = usage.session_cost(&self.active_model);
ui::render_session_summary(turn_count, usage.session_total, cost, elapsed);
let state = crate::session::SessionState {
last_provider: self.active_provider.clone(),
last_model: self.active_model.clone(),
last_session_turns: turn_count,
last_session_tokens: usage.session_total,
last_session_duration_secs: elapsed.as_secs(),
};
crate::session::save_session_state(&state);
}
}
println!();
}
/// Minimal line-based REPL without the boxed UI.
///
/// Supports `exit`/`quit`/`/exit`, `/reset`/`/clear` and `/help`. Any other
/// input that does not start with `/` is sent to the model via `run_stream`
/// and streamed to stdout. The final assistant message is printed only when no
/// reply text was streamed, so a streamed answer is not shown twice.
// No caller in this part of the file; kept as a fallback mode, hence the
// dead_code allowance.
#[allow(dead_code)]
async fn chat_raw(&mut self) {
    loop {
        let input = match ui::readline_raw() {
            Some(s) => s,
            None => {
                println!();
                break;
            }
        };
        if input.is_empty() {
            continue;
        }
        let cmd = input.to_lowercase();
        match cmd.as_str() {
            "exit" | "quit" | "/exit" => break,
            "/reset" | "/clear" => {
                self.reset().await;
                ui::render_info("Conversation cleared");
                continue;
            }
            "/help" | "help" => {
                ui::show_help();
                continue;
            }
            _ => {}
        }
        if cmd.starts_with('/') {
            ui::render_info(&format!("Unknown '{}'. /help", cmd));
            continue;
        }
        let spinner = ui::Spinner::start(&self.active_provider, &self.active_model);
        let mut stream = self.run_stream(input.clone(), true);
        use futures::StreamExt;
        let mut has_tokens = false;
        // Whether any reply text was already printed live (see the doc comment).
        let mut streamed_text = false;
        loop {
            match stream.next().await {
                Some(StreamEvent::Token(token)) => {
                    if !has_tokens {
                        spinner.stop();
                        has_tokens = true;
                        println!(" {} {}", style("▎").blue().bold(), style("Cortex").bold().blue());
                    }
                    streamed_text = true;
                    print!("{}", token);
                    std::io::stdout().flush().ok();
                }
                Some(StreamEvent::Done) => {
                    if !has_tokens {
                        spinner.stop();
                    }
                    println!();
                    println!();
                    break;
                }
                Some(StreamEvent::Error(e)) => {
                    spinner.stop();
                    println!();
                    ui::render_error(&e);
                    break;
                }
                Some(StreamEvent::ToolStart { name, args }) => {
                    if !has_tokens {
                        spinner.stop();
                        has_tokens = true;
                    }
                    println!();
                    ui::render_tool_line(&name, &args, "running");
                }
                Some(StreamEvent::ToolEnd { name, success }) => {
                    ui::render_tool_line(&name, "", if success { "done" } else { "error" });
                }
                _ => {}
            }
        }
        if streamed_text {
            continue;
        }
        let msgs = self.messages.lock().await;
        if let Some(last) = msgs.last() {
            if let Some(ref content) = last.content {
                if !content.trim().is_empty() && last.role == "assistant" {
                    println!("{}", content.trim());
                }
            }
        }
    }
}
}
/// Executes one tool call, retrying failures, and returns the tool's output text.
///
/// - Unknown tool names produce `Error: unknown tool …` listing the available tools.
/// - Up to three attempts are made, sleeping 500 ms and then 1000 ms between them.
/// - If an error mentions "timeout" (and attempts remain), the call is retried
///   once immediately with the `timeout` argument doubled (10 when absent).
///   A success there returns the real tool output.
/// - If the memory store has a resolved lesson with a fix for this error, the
///   timeout-extension retry is skipped for that attempt. Note that the lesson's
///   fix itself is not applied here.
/// - After the final failure, a lesson is recorded (unless a matching one
///   exists), the error is appended to `tool_errors` under the tool name, and an
///   `Error executing … after retry: …` message is returned.
///
/// NOTE(review): `std::thread::sleep` blocks the calling thread. Callers invoke
/// this from async code, so retries stall a Tokio worker; consider moving the
/// call into `spawn_blocking` at the call sites.
fn execute_tool(
    tool_call: &ToolCall,
    tools: &Arc<HashMap<String, Arc<ToolSpec>>>,
    memory_store: &Option<Arc<std::sync::Mutex<MemoryStore>>>,
    verbose: bool,
    tool_errors: &std::sync::Mutex<HashMap<String, Vec<String>>>,
) -> String {
    let tool = match tools.get(&tool_call.name) {
        Some(t) => t,
        None => {
            return format!(
                "Error: unknown tool '{}'. Available: {}",
                tool_call.name,
                tools.keys().cloned().collect::<Vec<_>>().join(", ")
            )
        }
    };
    if verbose {
        eprintln!(" [tool] {}({:?})", tool_call.name, tool_call.arguments);
    }
    let mut last_error = String::new();
    let max_attempts = 3;
    for attempt in 0..max_attempts {
        if attempt > 0 {
            let delay_ms = 250 * (2_u64.pow(attempt as u32));
            if verbose {
                eprintln!(" [retry] attempt {} waiting {}ms", attempt + 1, delay_ms);
            }
            std::thread::sleep(std::time::Duration::from_millis(delay_ms.min(4000)));
        }
        match tool.call(tool_call.arguments.clone()) {
            Ok(result) => return result,
            Err(e) => {
                last_error = e;
                if verbose {
                    eprintln!(" [error] attempt {}: {}", attempt + 1, last_error);
                }
                if let Some(ref store_mutex) = memory_store {
                    if let Ok(store) = store_mutex.lock() {
                        if let Ok(lessons) = store.search_lessons(&format!("{} {}", tool_call.name, last_error), 1) {
                            if lessons.first().map_or(false, |l| l.resolved && !l.fix.is_empty()) {
                                if verbose {
                                    eprintln!(" [self-heal] applying lesson");
                                }
                                continue;
                            }
                        }
                    }
                }
                if last_error.to_lowercase().contains("timeout") && attempt < max_attempts - 1 {
                    let mut args = tool_call.arguments.clone();
                    if let Some(obj) = args.as_object_mut() {
                        let current = obj.get("timeout").and_then(|v| v.as_u64()).unwrap_or(10);
                        obj.insert("timeout".into(), serde_json::Value::Number((current * 2).into()));
                    }
                    // Return the tool's actual output. Previously this was replaced
                    // by a "(retry succeeded)" placeholder and the result was lost.
                    if let Ok(output) = tool.call(args) {
                        return output;
                    }
                }
            }
        }
    }
    if let Some(ref store_mutex) = memory_store {
        if let Ok(store) = store_mutex.lock() {
            let context = serde_json::to_string(&tool_call.arguments).unwrap_or_default();
            let trigger = format!("{} failed: {}", tool_call.name, last_error);
            let existing = store.search_lessons(&trigger, 1).ok();
            let already_saved = existing.map_or(false, |e| !e.is_empty());
            if !already_saved {
                let _ = store.save_lesson(&trigger, "", &context, false);
            }
        }
    }
    if let Ok(mut errors) = tool_errors.lock() {
        errors.entry(tool_call.name.clone()).or_default().push(last_error.clone());
    }
    format!("Error executing {} after retry: {}", tool_call.name, last_error)
}
/// Stores `api_key` for `provider_name` in the first readable config file.
///
/// Candidate files are `./config.yaml`, then `$HOME/.cortex/config.yaml`. In the first one that
/// can be read:
/// * if an `api_key:` line appears under the provider's header line (before any other key), it
///   is replaced with the new key, keeping that line's indentation;
/// * otherwise a new provider entry with the key and a default OpenAI `base_url` is appended.
///
/// A trailing newline in the original file is preserved.
///
/// # Errors
/// Returns an error if no candidate file can be read, or if writing the file fails.
fn save_api_key_to_config(provider_name: &str, api_key: &str) -> Result<(), String> {
    let home = std::env::var("HOME").unwrap_or_default();
    let config_paths = ["config.yaml".to_string(), format!("{}/.cortex/config.yaml", home)];
    let header = format!("{}:", provider_name);
    for path in &config_paths {
        let content = match std::fs::read_to_string(path) {
            Ok(c) => c,
            Err(_) => continue,
        };

        // Scan for the provider header, then for its `api_key:` line. Any other key ends the
        // search, so only an api_key that precedes the provider's other settings is replaced.
        // Matching on the header line itself (instead of a substring of the whole file) avoids
        // false hits on providers whose name merely ends with `provider_name`.
        let mut lines: Vec<String> = content.lines().map(String::from).collect();
        let mut in_target = false;
        let mut replaced = false;
        for line in lines.iter_mut() {
            let trimmed = line.trim_start();
            let is_header = trimmed.starts_with(&header) && !trimmed.starts_with('#');
            let is_api_key = trimmed.starts_with("api_key:");
            let is_other_key = !is_api_key
                && trimmed.starts_with(|c: char| c.is_alphanumeric())
                && trimmed.contains(':');
            if is_header {
                in_target = true;
            } else if in_target && is_other_key {
                in_target = false;
            }
            if in_target && is_api_key {
                let indent_len = line.len() - trimmed.len();
                let new_line = format!("{}api_key: {}", &line[..indent_len], api_key);
                *line = new_line;
                replaced = true;
                in_target = false;
            }
        }

        let new_content = if replaced {
            let mut joined = lines.join("\n");
            if content.ends_with('\n') {
                joined.push('\n');
            }
            joined
        } else {
            let mut appended = content.trim_end().to_string();
            appended.push_str(&format!(
                "\n {}:\n api_key: {}\n base_url: https://api.openai.com/v1\n",
                provider_name, api_key
            ));
            appended
        };
        std::fs::write(path, &new_content).map_err(|e| format!("Failed to write config: {}", e))?;
        return Ok(());
    }
    Err("No config.yaml found to save to.".into())
}
/// Stores `model` for `provider_name` in the first config file that contains that provider.
///
/// Candidate files are `./config.yaml`, then `$HOME/.cortex/config.yaml`. Within a file:
/// * an existing `model:` line inside the provider's section is replaced, keeping its
///   indentation. The section starts at the provider's header line and continues across
///   `model:`, `api_key:` and `base_url:` keys.
/// * otherwise, if an `api_key:` line directly follows the header, a `model:` line with the
///   same indentation is inserted right after it.
///
/// A file where neither applies is left untouched and the next candidate is tried. A trailing
/// newline in the rewritten file is preserved.
///
/// # Errors
/// Returns an error if no candidate file contains the provider, or if writing fails.
// NOTE(review): presumably unused in the current build (hence the allow) — confirm before removing.
#[allow(dead_code)]
fn save_model_to_config(provider_name: &str, model: &str) -> Result<(), String> {
    let home = std::env::var("HOME").unwrap_or_default();
    let config_paths = ["config.yaml".to_string(), format!("{}/.cortex/config.yaml", home)];
    let header = format!("{}:", provider_name);
    for path in &config_paths {
        let content = match std::fs::read_to_string(path) {
            Ok(c) => c,
            Err(_) => continue,
        };
        let mut lines: Vec<String> = content.lines().map(String::from).collect();
        let mut changed = false;

        // Pass 1: replace an existing `model:` inside the provider's section.
        let mut in_target = false;
        for line in lines.iter_mut() {
            let trimmed = line.trim_start();
            let is_model = trimmed.starts_with("model:");
            let is_section_key =
                is_model || trimmed.starts_with("api_key:") || trimmed.starts_with("base_url:");
            if trimmed.starts_with(&header) && !trimmed.starts_with('#') {
                in_target = true;
            } else if in_target
                && !is_section_key
                && trimmed.starts_with(|c: char| c.is_alphanumeric())
                && trimmed.contains(':')
            {
                in_target = false;
            }
            if in_target && is_model {
                let indent_len = line.len() - trimmed.len();
                let new_line = format!("{}model: {}", &line[..indent_len], model);
                *line = new_line;
                changed = true;
                in_target = false;
            }
        }

        // Pass 2: no model yet — insert one after an `api_key:` that directly follows the header.
        if !changed {
            let insert_after = (1..lines.len()).find(|&i| {
                lines[i].trim_start().starts_with("api_key:") && lines[i - 1].trim() == header
            });
            if let Some(i) = insert_after {
                let indent_len = lines[i].len() - lines[i].trim_start().len();
                let new_line = format!("{}model: {}", &lines[i][..indent_len], model);
                lines.insert(i + 1, new_line);
                changed = true;
            }
        }

        if !changed {
            continue;
        }
        let mut new_content = lines.join("\n");
        if content.ends_with('\n') {
            new_content.push('\n');
        }
        std::fs::write(path, &new_content).map_err(|e| format!("Failed to write config: {}", e))?;
        return Ok(());
    }
    Err(format!("No config.yaml found with provider '{}' to save to.", provider_name))
}
/// Updates the `active_provider` / `active_model` entries in the first config file that has them.
///
/// Candidate files are `./config.yaml`, then `$HOME/.cortex/config.yaml`. Every line whose
/// trimmed text starts with `active_provider:` or `active_model:` is rewritten with the new
/// value, keeping its original indentation. A trailing newline in the file is preserved.
/// A file containing neither key is left untouched and the next candidate is tried.
///
/// # Errors
/// Returns an error if no candidate file contains either key, or if writing fails.
fn save_active_config(provider: &str, model: &str) -> Result<(), String> {
    let home = std::env::var("HOME").unwrap_or_default();
    let config_paths = ["config.yaml".to_string(), format!("{}/.cortex/config.yaml", home)];
    for path in &config_paths {
        let content = match std::fs::read_to_string(path) {
            Ok(c) => c,
            Err(_) => continue,
        };
        let mut lines: Vec<String> = content.lines().map(String::from).collect();
        let mut changed = false;
        for line in lines.iter_mut() {
            let trimmed = line.trim_start();
            let update = if trimmed.starts_with("active_provider:") {
                Some(("active_provider", provider))
            } else if trimmed.starts_with("active_model:") {
                Some(("active_model", model))
            } else {
                None
            };
            if let Some((key, value)) = update {
                let indent = &line[..line.len() - trimmed.len()];
                let new_line = format!("{}{}: {}", indent, key, value);
                *line = new_line;
                changed = true;
            }
        }
        if changed {
            let mut new_content = lines.join("\n");
            // `lines()` drops the final newline; restore it so the file is otherwise unchanged.
            if content.ends_with('\n') {
                new_content.push('\n');
            }
            std::fs::write(path, &new_content).map_err(|e| format!("Failed to write config: {}", e))?;
            return Ok(());
        }
    }
    Err("No config.yaml found with active_provider/active_model fields.".into())
}
/// Lists the model ids advertised by an OpenAI-compatible `GET {base_url}/models` endpoint.
///
/// This is best-effort: an empty or unresolved (`${...}`) key, a transport error, a
/// non-success status, an unparsable body, or a body without a `data` array all yield an
/// empty list. Ids come from each entry's `id` string and are returned sorted and
/// de-duplicated. The request times out after 5 seconds.
async fn fetch_provider_models(base_url: &str, api_key: &str) -> Vec<String> {
    // Without a usable key the request can only fail, so don't send it.
    if api_key.is_empty() || api_key.contains("${") {
        return Vec::new();
    }

    let endpoint = format!("{}/models", base_url.trim_end_matches('/'));
    let client = reqwest::Client::new();
    let outcome = client
        .get(&endpoint)
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .timeout(std::time::Duration::from_secs(5))
        .send()
        .await;

    let response = match outcome {
        Ok(r) if r.status().is_success() => r,
        _ => return Vec::new(),
    };
    let body = match response.json::<serde_json::Value>().await {
        Ok(v) => v,
        Err(_) => return Vec::new(),
    };

    let mut ids: Vec<String> = body["data"]
        .as_array()
        .map(|entries| {
            entries
                .iter()
                .filter_map(|entry| entry["id"].as_str())
                .map(str::to_owned)
                .collect()
        })
        .unwrap_or_default();
    ids.sort();
    ids.dedup();
    ids
}
fn extract_session_topics(messages: &Arc<tokio::sync::Mutex<Vec<Message>>>) -> Vec<String> {
let stopwords: std::collections::HashSet<&str> = [
"the","a","an","is","are","was","were","be","been","being","have","has","had",
"do","does","did","will","would","could","should","may","might","shall","can",
"to","of","in","for","on","with","at","by","from","as","into","through","during",
"before","after","above","below","between","out","off","over","under","again",
"further","then","once","here","there","when","where","why","how","all","each",
"every","both","few","more","most","other","some","such","no","nor","not","only",
"own","same","so","than","too","very","just","because","but","and","or","if",
"while","that","this","these","those","it","its","what","which","who","whom",
"about","up","down","like","also","get","got","tell","lets","let","know","make",
"doesnt","dont","cant","wont","please","help","need","want","see","use","say",
"hi","hello","hey","thanks","thank","yes","no","ok","okay","sure","right",
].iter().cloned().collect();
let msgs = messages.try_lock().ok();
let user_msgs: Vec<String> = match msgs {
Some(ref guard) => guard.iter()
.filter(|m| m.role == "user")
.map(|m| m.content.as_deref().unwrap_or("").to_string())
.collect(),
None => return vec![],
};
let mut word_counts: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
for msg in &user_msgs {
let clean: String = msg.chars()
.map(|c| if c.is_alphanumeric() || c.is_whitespace() { c } else { ' ' })
.collect();
for word in clean.split_whitespace() {
let lower = word.to_lowercase();
if lower.len() > 3 && !stopwords.contains(lower.as_str()) {
*word_counts.entry(lower).or_insert(0) += 1;
}
}
}
let mut topic_list: Vec<(String, usize)> = word_counts.into_iter()
.filter(|(_, count)| *count > 1) .collect();
topic_list.sort_by(|a, b| b.1.cmp(&a.1));
let mut topics: Vec<String> = topic_list.into_iter()
.take(5)
.map(|(word, _)| word)
.collect();
for topic in topics.iter_mut() {
if let Some(c) = topic.chars().next() {
*topic = c.to_uppercase().to_string() + &topic[1..];
}
}
topics
}