use super::{
PersonalFact, PersonalFactCategory, PersonalFactCollector, PersonalFactSource,
PersonalKnowledgeCache, PersonalKnowledgeSettings,
};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
/// Minimum seconds between behavioral-inference passes in `ToolUsageTracker`.
const KNOWLEDGE_INFERENCE_INTERVAL_SECS: u64 = 300;
/// Default seconds between server sync attempts in `PksSseListener`.
const KNOWLEDGE_SYNC_INTERVAL_SECS: u64 = 60;
/// Wires personal-knowledge fact detection into the agent runtime:
/// extracts facts from user messages and observed behavior, then fans
/// them out to an optional channel and an optional shared cache.
pub struct PksIntegration {
    // Feature toggles and thresholds supplied by the user.
    settings: PersonalKnowledgeSettings,
    // Parses facts out of raw user messages (implicit detection).
    collector: PersonalFactCollector,
    // Optional shared store; detected facts are upserted here when present.
    cache: Option<Arc<Mutex<PersonalKnowledgeCache>>>,
    // Per-tool success/failure counters driving behavioral inference.
    tool_usage: ToolUsageTracker,
    // Optional sink for detected facts (e.g. a background processor).
    fact_tx: Option<mpsc::UnboundedSender<DetectedFact>>,
}
/// A detected fact paired with the mechanism that produced it.
#[derive(Debug, Clone)]
pub struct DetectedFact {
    pub fact: PersonalFact,
    pub detection_source: DetectionSource,
}
/// How a [`DetectedFact`] was discovered.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DetectionSource {
    /// Parsed directly out of a user message by the collector.
    ImplicitDetection,
    /// Inferred from observed behavior (tool usage, working directory).
    BehavioralInference,
    /// Received from the PKS server during a background sync.
    ServerSync,
}
impl PksIntegration {
pub fn new(settings: PersonalKnowledgeSettings) -> Self {
let collector = PersonalFactCollector::new(
settings.implicit_detection_confidence,
settings.enable_implicit_learning,
);
Self {
settings,
collector,
cache: None,
tool_usage: ToolUsageTracker::new(),
fact_tx: None,
}
}
pub fn with_cache(mut self, cache: Arc<Mutex<PersonalKnowledgeCache>>) -> Self {
self.cache = Some(cache);
self
}
pub fn with_fact_channel(mut self, tx: mpsc::UnboundedSender<DetectedFact>) -> Self {
self.fact_tx = Some(tx);
self
}
pub fn process_user_message(&mut self, message: &str) -> usize {
if !self.settings.enabled || !self.settings.enable_implicit_learning {
return 0;
}
let facts = self.collector.process_message(message);
let count = facts.len();
for fact in facts {
self.emit_fact(fact, DetectionSource::ImplicitDetection);
}
count
}
pub fn record_tool_usage(&mut self, tool_name: &str, success: bool) {
if !self.settings.enabled || !self.settings.enable_observed_learning {
return;
}
self.tool_usage.record(tool_name, success);
if let Some(facts) = self.tool_usage.infer_facts() {
for fact in facts {
self.emit_fact(fact, DetectionSource::BehavioralInference);
}
}
}
pub fn record_working_directory(&mut self, path: &str) {
if !self.settings.enabled || !self.settings.enable_observed_learning {
return;
}
if let Some(project_name) = extract_project_name(path) {
let fact = PersonalFact::new(
PersonalFactCategory::Context,
"current_project".to_string(),
project_name,
Some(format!("Working directory: {}", path)),
PersonalFactSource::SystemObserved,
false,
);
self.emit_fact(fact, DetectionSource::BehavioralInference);
}
}
fn emit_fact(&mut self, fact: PersonalFact, source: DetectionSource) {
if let Some(ref tx) = self.fact_tx {
let detected = DetectedFact {
fact: fact.clone(),
detection_source: source,
};
let _ = tx.send(detected);
}
if let Some(ref cache) = self.cache
&& let Ok(mut cache) = cache.lock()
&& let Err(e) = cache.upsert_fact(fact)
{
tracing::warn!("Failed to store detected fact: {}", e);
}
}
pub fn is_enabled(&self) -> bool {
self.settings.enabled
}
pub fn settings(&self) -> &PersonalKnowledgeSettings {
&self.settings
}
}
impl Default for PksIntegration {
    /// Equivalent to `PksIntegration::new` with default settings.
    fn default() -> Self {
        let settings = PersonalKnowledgeSettings::default();
        Self::new(settings)
    }
}
/// Counts per-tool successes and failures, and periodically promotes heavy,
/// successful usage into preference/capability facts.
pub struct ToolUsageTracker {
    // tool name -> (success count, failure count)
    usage: HashMap<String, (u32, u32)>,
    // When facts were last inferred; throttles `infer_facts`.
    last_inference: Instant,
    // Minimum time between inference passes.
    inference_interval: Duration,
    // Successful uses required before a tool yields a fact.
    min_uses_for_inference: u32,
}
impl ToolUsageTracker {
    /// Fresh tracker: no usage recorded, inference throttled from "now",
    /// five successful uses required per inference.
    fn new() -> Self {
        Self {
            usage: HashMap::new(),
            last_inference: Instant::now(),
            inference_interval: Duration::from_secs(KNOWLEDGE_INFERENCE_INTERVAL_SECS),
            min_uses_for_inference: 5,
        }
    }

    /// Bumps the success or failure counter for one tool invocation.
    fn record(&mut self, tool_name: &str, success: bool) {
        let (ok, failed) = self.usage.entry(tool_name.to_string()).or_insert((0, 0));
        if success {
            *ok += 1;
        } else {
            *failed += 1;
        }
    }

    /// Produces facts from accumulated usage, at most once per interval.
    ///
    /// Returns `None` when throttled or when nothing crosses the usage
    /// threshold; the throttle timestamp is reset only when facts were
    /// actually produced (same as the original logic).
    fn infer_facts(&mut self) -> Option<Vec<PersonalFact>> {
        if self.last_inference.elapsed() < self.inference_interval {
            return None;
        }
        // One preference fact per tool with enough successful uses.
        let mut inferred: Vec<PersonalFact> = self
            .usage
            .iter()
            .filter(|(_, (ok, _))| *ok >= self.min_uses_for_inference)
            .map(|(name, (ok, _))| {
                PersonalFact::new(
                    PersonalFactCategory::Preference,
                    format!("preferred_{}_tool", categorize_tool(name)),
                    name.clone(),
                    Some(format!("Used {} times successfully", ok)),
                    PersonalFactSource::SystemObserved,
                    false,
                )
            })
            .collect();
        // Capability facts for tool families with enough combined successes.
        let skill_areas: [(&[&str], &str, &str); 2] = [
            (
                &["read_file", "write_file", "edit_file", "glob", "grep"],
                "file_operations_proficiency",
                "file",
            ),
            (
                &["git_status", "git_diff", "git_log", "git_commit"],
                "git_proficiency",
                "git",
            ),
        ];
        for (tools, key, noun) in skill_areas {
            let ops = self.count_category_usage(tools);
            if ops >= self.min_uses_for_inference {
                inferred.push(PersonalFact::new(
                    PersonalFactCategory::Capability,
                    key.to_string(),
                    "proficient".to_string(),
                    Some(format!("Completed {} {} operations", ops, noun)),
                    PersonalFactSource::SystemObserved,
                    false,
                ));
            }
        }
        if inferred.is_empty() {
            None
        } else {
            self.last_inference = Instant::now();
            Some(inferred)
        }
    }

    /// Total successful uses across the given tool names.
    fn count_category_usage(&self, tools: &[&str]) -> u32 {
        tools
            .iter()
            .filter_map(|name| self.usage.get(*name))
            .map(|&(ok, _)| ok)
            .sum()
    }
}
/// Maps a tool name onto a coarse category label; unknown tools fall back
/// to "general".
fn categorize_tool(tool_name: &str) -> &'static str {
    // (tool names, category) lookup table; first match wins.
    const CATEGORIES: &[(&[&str], &str)] = &[
        (&["read_file", "write_file", "edit_file", "glob", "grep"], "file"),
        (&["bash", "execute_command"], "shell"),
        (&["git_status", "git_diff", "git_log", "git_commit"], "git"),
        (&["web_search", "fetch_url"], "web"),
        (&["semantic_search", "context_recall"], "search"),
    ];
    CATEGORIES
        .iter()
        .find(|(tools, _)| tools.contains(&tool_name))
        .map(|&(_, category)| category)
        .unwrap_or("general")
}
/// Guesses a project name from a directory path.
///
/// If the directory contains a recognizable project marker (checked against
/// the real filesystem), its basename is returned. Otherwise the basename is
/// returned unless it is a generic system directory name.
fn extract_project_name(path: &str) -> Option<String> {
    use std::path::Path;
    // Files/dirs whose presence marks a project root.
    const PROJECT_MARKERS: [&str; 5] = [
        "Cargo.toml",
        "package.json",
        "pyproject.toml",
        "go.mod",
        ".git",
    ];
    // Basenames that are never project names.
    const GENERIC_DIRS: [&str; 7] = ["home", "usr", "var", "tmp", "etc", "Users", "root"];
    let dir = Path::new(path);
    if PROJECT_MARKERS.iter().any(|marker| dir.join(marker).exists()) {
        return dir.file_name().and_then(|n| n.to_str()).map(str::to_string);
    }
    let name = dir.file_name()?.to_str()?;
    (!GENERIC_DIRS.contains(&name)).then(|| name.to_string())
}
/// Periodic client-side sync worker for personal knowledge facts: uploads
/// locally pending submissions and forwards server-received facts into a
/// channel.
///
/// NOTE(review): the name suggests Server-Sent Events, but `listen` polls
/// `sync` on a fixed interval — confirm whether the name is historical.
pub struct PksSseListener {
    // HTTP client for the PKS server API.
    api_client: super::api::PersonalKnowledgeApiClient,
    // Sink for facts received from the server.
    fact_tx: mpsc::UnboundedSender<DetectedFact>,
    // Optional local cache; supplies pending uploads, cleared after upload.
    cache: Option<Arc<Mutex<PersonalKnowledgeCache>>>,
    // Optional shutdown signal that stops the `listen` loop.
    shutdown_rx: Option<tokio::sync::oneshot::Receiver<()>>,
    // Time between sync attempts.
    sync_interval: Duration,
    // Server timestamp of the last successful sync (incremental cursor).
    last_sync: Option<String>,
}
impl PksSseListener {
    /// Creates a sync worker for `server_url`, forwarding received facts
    /// into `fact_tx`. Defaults: 60-second interval, no cache, no shutdown
    /// signal, full (non-incremental) first sync.
    pub fn new(server_url: &str, fact_tx: mpsc::UnboundedSender<DetectedFact>) -> Self {
        let api_client = super::api::PersonalKnowledgeApiClient::new(server_url);
        Self {
            api_client,
            fact_tx,
            cache: None,
            shutdown_rx: None,
            sync_interval: Duration::from_secs(KNOWLEDGE_SYNC_INTERVAL_SECS),
            last_sync: None,
        }
    }

    /// Sets the auth token on the API client.
    pub fn with_auth(mut self, token: String) -> Self {
        self.api_client.set_auth_token(token);
        self
    }

    /// Attaches the local cache used to source pending uploads.
    pub fn with_cache(mut self, cache: Arc<Mutex<PersonalKnowledgeCache>>) -> Self {
        self.cache = Some(cache);
        self
    }

    /// Overrides the sync interval.
    pub fn with_interval(mut self, interval: Duration) -> Self {
        self.sync_interval = interval;
        self
    }

    /// Attaches a oneshot receiver whose firing stops the `listen` loop.
    pub fn with_shutdown(mut self, rx: tokio::sync::oneshot::Receiver<()>) -> Self {
        self.shutdown_rx = Some(rx);
        self
    }

    /// Runs the sync loop until the shutdown signal fires (or forever when
    /// none was attached). Individual sync failures are logged at debug
    /// level and never stop the loop.
    pub async fn listen(mut self) -> Result<()> {
        tracing::info!(
            "Starting PKS background sync (interval: {:?})",
            self.sync_interval
        );
        let mut shutdown_rx = self.shutdown_rx.take();
        let mut interval = tokio::time::interval(self.sync_interval);
        // Eager first sync; failure is expected when not yet authenticated.
        if let Err(e) = self.perform_sync().await {
            tracing::debug!("Initial PKS sync failed (may not be logged in): {}", e);
        }
        loop {
            tokio::select! {
                // Shutdown branch: completes when the signal fires. With no
                // receiver attached it pends forever, leaving only the timer
                // branch active.
                _ = async {
                    if let Some(ref mut rx) = shutdown_rx {
                        rx.await.ok();
                    } else {
                        std::future::pending::<()>().await;
                    }
                } => {
                    tracing::info!("PKS background sync shutting down");
                    break;
                }
                // Timer branch: one sync attempt per tick.
                _ = interval.tick() => {
                    if let Err(e) = self.perform_sync().await {
                        tracing::debug!("PKS sync error: {}", e);
                    }
                }
            }
        }
        Ok(())
    }

    /// One sync round trip: snapshot pending local facts, exchange with the
    /// server, forward received facts to the channel, then clear the pending
    /// queue if anything was uploaded.
    async fn perform_sync(&mut self) -> Result<()> {
        // Snapshot pending submissions before the request. A failed (e.g.
        // poisoned) lock silently degrades to "nothing to upload".
        let pending_facts: Vec<PersonalFact> = if let Some(ref cache) = self.cache {
            if let Ok(cache) = cache.lock() {
                cache
                    .pending_submissions()
                    .iter()
                    .map(|p| p.fact.clone())
                    .collect()
            } else {
                Vec::new()
            }
        } else {
            Vec::new()
        };
        // NOTE(review): the `None`, `&[]`, `0.5`, `100` arguments are
        // presumably a filter, exclusions, a confidence threshold, and a
        // result limit — confirm against the `api` module's sync signature.
        let sync_result = self
            .api_client
            .sync(
                self.last_sync.as_deref(),
                None,
                &pending_facts,
                &[],
                0.5,
                100,
            )
            .await;
        match sync_result {
            Ok(result) => {
                // Advance the incremental-sync cursor to the server's time.
                self.last_sync = Some(result.sync_timestamp.clone());
                let received_count = result.facts.len();
                let uploaded_count = pending_facts.len();
                // Forward server facts to the background processor.
                for fact in result.facts {
                    let detected = DetectedFact {
                        fact,
                        detection_source: DetectionSource::ServerSync,
                    };
                    if let Err(e) = self.fact_tx.send(detected) {
                        tracing::warn!("Failed to send synced fact to channel: {}", e);
                    }
                }
                // NOTE(review): this clears ALL pending submissions, including
                // any queued while the request was in flight — confirm that
                // window is acceptable.
                if uploaded_count > 0
                    && let Some(ref cache) = self.cache
                    && let Ok(mut cache) = cache.lock()
                    && let Err(e) = cache.clear_pending_submissions()
                {
                    tracing::warn!("Failed to clear pending submissions: {}", e);
                }
                if received_count > 0 || uploaded_count > 0 {
                    tracing::debug!(
                        "PKS sync complete: received {} facts, uploaded {} facts",
                        received_count,
                        uploaded_count
                    );
                }
                Ok(())
            }
            Err(e) => Err(e),
        }
    }
}
/// Drains detected facts from a channel and persists them into the cache,
/// turning repeat detections into reinforcement feedback.
pub struct PksBackgroundProcessor {
    // Source of detected facts (fed by PksIntegration / PksSseListener).
    fact_rx: mpsc::UnboundedReceiver<DetectedFact>,
    // Shared fact store updated by `process_fact`.
    cache: Arc<Mutex<PersonalKnowledgeCache>>,
}
impl PksBackgroundProcessor {
    /// Builds a processor over a fact channel and a shared cache.
    /// The settings parameter is currently unused but kept for the callers.
    pub fn new(
        fact_rx: mpsc::UnboundedReceiver<DetectedFact>,
        cache: Arc<Mutex<PersonalKnowledgeCache>>,
        _settings: PersonalKnowledgeSettings,
    ) -> Self {
        Self { fact_rx, cache }
    }

    /// Consumes the channel until every sender is dropped, persisting each
    /// fact; per-fact failures are logged and do not stop the loop.
    pub async fn run(mut self) {
        tracing::info!("PKS background processor started");
        loop {
            match self.fact_rx.recv().await {
                Some(detected) => {
                    if let Err(e) = self.process_fact(detected) {
                        tracing::warn!("Failed to process detected fact: {}", e);
                    }
                }
                None => break,
            }
        }
        tracing::info!("PKS background processor stopped");
    }

    /// Stores one detected fact.
    ///
    /// New keys are upserted directly. For an existing key: server-synced
    /// facts overwrite unless the local copy is marked local-only; locally
    /// detected repeats with a matching source become reinforcement
    /// feedback instead of a write.
    fn process_fact(&self, detected: DetectedFact) -> Result<()> {
        let mut store = self
            .cache
            .lock()
            .map_err(|e| anyhow::anyhow!("Cache lock error: {}", e))?;
        match store.get_fact_by_key(&detected.fact.key) {
            None => {
                // First sighting of this key: store it as-is.
                store.upsert_fact(detected.fact)?;
            }
            Some(prior) => match detected.detection_source {
                DetectionSource::ServerSync => {
                    // Server data wins unless the local copy is pinned.
                    if !prior.local_only {
                        store.upsert_fact(detected.fact)?;
                    }
                }
                DetectionSource::ImplicitDetection | DetectionSource::BehavioralInference => {
                    // Same source seeing the same key again reinforces it.
                    if prior.source == detected.fact.source {
                        use super::PersonalFactFeedback;
                        let reinforcement = PersonalFactFeedback {
                            fact_id: prior.id.clone(),
                            is_reinforcement: true,
                            context: Some(format!(
                                "Detected again via {:?}",
                                detected.detection_source
                            )),
                            timestamp: chrono::Utc::now().timestamp(),
                        };
                        // Best-effort: feedback queueing failures are ignored.
                        let _ = store.queue_feedback(reinforcement);
                    }
                }
            },
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pks_integration_creation() {
        // Default settings leave the feature enabled.
        assert!(PksIntegration::default().is_enabled());
    }

    #[test]
    fn test_process_user_message() {
        let mut integration = PksIntegration::default();
        // A self-introduction should yield at least one implicit fact.
        assert!(integration.process_user_message("My name is John Smith") > 0);
    }

    #[test]
    fn test_process_user_message_disabled() {
        // With implicit learning off, the same message yields nothing.
        let settings = PersonalKnowledgeSettings {
            enable_implicit_learning: false,
            ..Default::default()
        };
        let mut integration = PksIntegration::new(settings);
        assert_eq!(integration.process_user_message("My name is John Smith"), 0);
    }

    #[test]
    fn test_tool_usage_tracking() {
        let mut integration = PksIntegration::default();
        for _ in 0..6 {
            integration.record_tool_usage("read_file", true);
        }
        // The tracker must have counted the tool.
        assert!(integration.tool_usage.usage.contains_key("read_file"));
    }

    #[test]
    fn test_extract_project_name() {
        // "invalid" is not a generic system dir, so a name is extracted.
        assert!(extract_project_name("/home/user/invalid").is_some());
    }

    #[test]
    fn test_categorize_tool() {
        for (tool, expected) in [
            ("read_file", "file"),
            ("bash", "shell"),
            ("git_status", "git"),
            ("unknown_tool", "general"),
        ] {
            assert_eq!(categorize_tool(tool), expected);
        }
    }
}