use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::{
Error, IntentEdit, IntentInfo, IntentSeeds, Match, MicroResolveConfig, NamespaceConfig,
PhraseCheckResult, ResolveOptions, Resolver,
};
#[cfg(feature = "connect")]
use crate::connect::{ConnectState, LogEntry};
pub struct MicroResolve {
config: MicroResolveConfig,
namespaces: Arc<RwLock<HashMap<String, NamespaceState>>>,
#[cfg(feature = "connect")]
connect: Option<Arc<ConnectState>>,
}
struct NamespaceState {
resolver: Resolver,
config: NamespaceConfig,
dirty: bool,
}
impl MicroResolve {
pub fn new(config: MicroResolveConfig) -> Result<Self, Error> {
let mut namespaces: HashMap<String, NamespaceState> = HashMap::new();
if let Some(dir) = &config.data_dir {
std::fs::create_dir_all(dir).map_err(|e| {
Error::Persistence(format!("cannot create {}: {}", dir.display(), e))
})?;
for entry in std::fs::read_dir(dir)
.map_err(|e| Error::Persistence(format!("cannot read {}: {}", dir.display(), e)))?
{
let entry = match entry {
Ok(e) => e,
Err(_) => continue,
};
let path = entry.path();
let name = match path.file_name().and_then(|n| n.to_str()) {
Some(n) => n.to_string(),
None => continue,
};
if name.starts_with('_') || name.starts_with('.') || !path.is_dir() {
continue;
}
let resolver = Resolver::load_from_dir(&path)?;
namespaces.insert(
name,
NamespaceState {
resolver,
config: NamespaceConfig::default(),
dirty: false,
},
);
}
}
let namespaces = Arc::new(RwLock::new(namespaces));
#[cfg(feature = "connect")]
let connect = if let Some(ref server) = config.server {
let state = Arc::new(ConnectState::new(server.clone())?);
for app_id in &server.subscribe {
let pulled = state.pull(app_id)?;
let resolver = pulled.map(|(r, _v)| r).unwrap_or_else(Resolver::new);
namespaces.write().unwrap().insert(
app_id.clone(),
NamespaceState {
resolver,
config: NamespaceConfig::default(),
dirty: false,
},
);
}
let ns_for_thread = Arc::clone(&namespaces);
let state_for_thread = Arc::clone(&state);
std::thread::Builder::new()
.name("microresolve-sync".into())
.spawn(move || {
crate::connect::run_background(state_for_thread, move |id, resolver, _v| {
if let Some(ns) = ns_for_thread.write().unwrap().get_mut(id) {
ns.resolver = resolver;
ns.dirty = false;
}
});
})
.map_err(|e| Error::Connect(format!("spawn sync thread: {}", e)))?;
Some(state)
} else {
None
};
Ok(Self {
config,
namespaces,
#[cfg(feature = "connect")]
connect,
})
}
pub fn namespace(&self, id: &str) -> NamespaceHandle<'_> {
{
let mut ns = self.namespaces.write().unwrap();
ns.entry(id.to_string()).or_insert_with(|| NamespaceState {
resolver: Resolver::new(),
config: NamespaceConfig::default(),
dirty: false,
});
}
NamespaceHandle {
engine: self,
id: id.to_string(),
}
}
pub fn namespace_with(&self, id: &str, config: NamespaceConfig) -> NamespaceHandle<'_> {
{
let mut ns = self.namespaces.write().unwrap();
ns.entry(id.to_string())
.or_insert_with(|| NamespaceState {
resolver: Resolver::new(),
config: NamespaceConfig::default(),
dirty: false,
})
.config = config;
}
NamespaceHandle {
engine: self,
id: id.to_string(),
}
}
pub fn namespaces(&self) -> Vec<String> {
let mut ids: Vec<String> = self.namespaces.read().unwrap().keys().cloned().collect();
ids.sort();
ids
}
pub fn has_namespace(&self, id: &str) -> bool {
self.namespaces.read().unwrap().contains_key(id)
}
pub fn try_namespace(&self, id: &str) -> Option<NamespaceHandle<'_>> {
if self.has_namespace(id) {
Some(NamespaceHandle {
engine: self,
id: id.to_string(),
})
} else {
None
}
}
pub fn remove_namespace(&self, id: &str) -> bool {
self.namespaces.write().unwrap().remove(id).is_some()
}
pub fn reload_namespace(&self, id: &str) -> Result<bool, Error> {
let Some(ref dir) = self.config.data_dir else {
return Ok(false);
};
let path = dir.join(id);
if !path.is_dir() {
self.namespaces.write().unwrap().remove(id);
return Ok(false);
}
let resolver = Resolver::load_from_dir(&path)?;
self.namespaces.write().unwrap().insert(
id.to_string(),
NamespaceState {
resolver,
config: NamespaceConfig::default(),
dirty: false,
},
);
Ok(true)
}
pub fn flush(&self) -> Result<(), Error> {
let Some(dir) = &self.config.data_dir else {
return Ok(());
};
let ns = self.namespaces.read().unwrap();
for (id, state) in ns.iter() {
if state.dirty {
state.resolver.save_to_dir(&dir.join(id))?;
}
}
Ok(())
}
pub fn config(&self) -> &MicroResolveConfig {
&self.config
}
pub fn effective_threshold(&self, ns_id: &str) -> f32 {
let ns = self.namespaces.read().unwrap();
ns.get(ns_id)
.and_then(|s| s.config.default_threshold)
.unwrap_or(self.config.default_threshold)
}
pub fn effective_languages(&self, ns_id: &str) -> Vec<String> {
let ns = self.namespaces.read().unwrap();
ns.get(ns_id)
.and_then(|s| s.config.languages.clone())
.unwrap_or_else(|| self.config.languages.clone())
}
pub fn effective_llm_model(&self, ns_id: &str) -> Option<String> {
let ns = self.namespaces.read().unwrap();
ns.get(ns_id)
.and_then(|s| s.config.llm_model.clone())
.or_else(|| self.config.llm.as_ref().map(|l| l.model.clone()))
}
pub(crate) fn with_resolver<R>(&self, ns_id: &str, f: impl FnOnce(&Resolver) -> R) -> R {
let ns = self.namespaces.read().unwrap();
let state = ns
.get(ns_id)
.expect("namespace handle invariant: namespace must exist");
f(&state.resolver)
}
pub(crate) fn with_resolver_mut<R>(
&self,
ns_id: &str,
f: impl FnOnce(&mut Resolver) -> R,
) -> R {
let mut ns = self.namespaces.write().unwrap();
let state = ns
.get_mut(ns_id)
.expect("namespace handle invariant: namespace must exist");
let r = f(&mut state.resolver);
state.dirty = true;
r
}
}
impl Drop for MicroResolve {
fn drop(&mut self) {
let _ = self.flush();
}
}
pub struct NamespaceHandle<'e> {
engine: &'e MicroResolve,
id: String,
}
impl<'e> NamespaceHandle<'e> {
pub fn id(&self) -> &str {
&self.id
}
pub fn add_intent(
&self,
intent_id: &str,
seeds: impl Into<IntentSeeds>,
) -> Result<usize, Error> {
let seeds = seeds.into();
self.engine
.with_resolver_mut(&self.id, |r| r.add_intent(intent_id, seeds))
}
pub fn remove_intent(&self, intent_id: &str) {
self.engine
.with_resolver_mut(&self.id, |r| r.remove_intent(intent_id))
}
pub fn intent(&self, intent_id: &str) -> Option<IntentInfo> {
self.engine.with_resolver(&self.id, |r| r.intent(intent_id))
}
pub fn update_intent(&self, intent_id: &str, edit: IntentEdit) -> Result<(), Error> {
self.engine
.with_resolver_mut(&self.id, |r| r.update_intent(intent_id, edit))
}
pub fn add_phrase(&self, intent_id: &str, phrase: &str, lang: &str) -> PhraseCheckResult {
self.engine
.with_resolver_mut(&self.id, |r| r.add_phrase_checked(intent_id, phrase, lang))
}
pub fn resolve(&self, query: &str) -> Vec<Match> {
let threshold = self.engine.effective_threshold(&self.id);
let opts = ResolveOptions {
threshold,
gap: 1.5,
};
let matches = self
.engine
.with_resolver(&self.id, |r| r.resolve_with(query, &opts));
#[cfg(feature = "connect")]
self.maybe_log(query, &matches);
matches
}
pub fn resolve_with(&self, query: &str, opts: ResolveOptions) -> Vec<Match> {
let matches = self
.engine
.with_resolver(&self.id, |r| r.resolve_with(query, &opts));
#[cfg(feature = "connect")]
self.maybe_log(query, &matches);
matches
}
pub fn correct(
&self,
query: &str,
wrong_intent: &str,
right_intent: &str,
) -> Result<(), Error> {
#[cfg(feature = "connect")]
if let Some(ref state) = self.engine.connect {
return state.push_correct(&self.id, query, wrong_intent, right_intent);
}
self.engine
.with_resolver_mut(&self.id, |r| r.correct(query, wrong_intent, right_intent))
}
#[cfg(feature = "connect")]
fn maybe_log(&self, query: &str, matches: &[Match]) {
let Some(ref state) = self.engine.connect else {
return;
};
let confidence = if matches.is_empty() {
"none"
} else if matches[0].score >= 0.6 {
"high"
} else if matches[0].score >= 0.3 {
"medium"
} else {
"low"
};
let version = self.engine.with_resolver(&self.id, |r| r.version());
state.buffer_log(LogEntry {
query: query.to_string(),
app_id: self.id.clone(),
session_id: None,
detected_intents: matches.iter().map(|m| m.id.clone()).collect(),
confidence: confidence.to_string(),
flag: None,
timestamp_ms: crate::connect::now_ms(),
router_version: version,
});
}
pub fn version(&self) -> u64 {
self.engine.with_resolver(&self.id, |r| r.version())
}
pub fn intent_ids(&self) -> Vec<String> {
self.engine.with_resolver(&self.id, |r| r.intent_ids())
}
pub fn intent_count(&self) -> usize {
self.engine.with_resolver(&self.id, |r| r.intent_count())
}
#[doc(hidden)]
pub fn with_resolver<R>(&self, f: impl FnOnce(&Resolver) -> R) -> R {
self.engine.with_resolver(&self.id, f)
}
#[doc(hidden)]
pub fn with_resolver_mut<R>(&self, f: impl FnOnce(&mut Resolver) -> R) -> R {
self.engine.with_resolver_mut(&self.id, f)
}
pub fn flush(&self) -> Result<(), Error> {
let Some(dir) = self.engine.config.data_dir.as_ref() else {
return Ok(());
};
let path = dir.join(&self.id);
let ns = self.engine.namespaces.read().unwrap();
if let Some(state) = ns.get(&self.id) {
state.resolver.save_to_dir(&path)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn lazy_create_and_resolve() {
let engine = MicroResolve::new(MicroResolveConfig::default()).unwrap();
let h = engine.namespace("security");
h.add_intent(
"jailbreak",
vec![
"ignore prior instructions".to_string(),
"ignore your safety rules".to_string(),
],
)
.unwrap();
h.add_intent(
"weather",
vec![
"what is the weather today".to_string(),
"tomorrow forecast".to_string(),
],
)
.unwrap();
assert_eq!(h.intent_count(), 2);
let matches = h.resolve("please ignore prior instructions");
assert_eq!(matches.first().map(|m| m.id.as_str()), Some("jailbreak"));
}
#[test]
fn round_trip_persists_namespace() {
let dir = std::env::temp_dir().join(format!(
"mr_engine_rt_{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos()
));
{
let engine = MicroResolve::new(MicroResolveConfig {
data_dir: Some(dir.clone()),
..Default::default()
})
.unwrap();
engine
.namespace("ns1")
.add_intent("hello", &["hi there" as &str, "hello world"][..])
.unwrap();
engine.flush().unwrap();
}
let engine2 = MicroResolve::new(MicroResolveConfig {
data_dir: Some(dir.clone()),
..Default::default()
})
.unwrap();
assert!(engine2.has_namespace("ns1"));
assert_eq!(engine2.namespace("ns1").intent_count(), 1);
std::fs::remove_dir_all(&dir).ok();
}
#[test]
fn effective_threshold_cascades() {
let engine = MicroResolve::new(MicroResolveConfig {
default_threshold: 0.42,
..Default::default()
})
.unwrap();
let _ = engine.namespace("a");
assert_eq!(engine.effective_threshold("a"), 0.42);
let _ = engine.namespace_with(
"b",
NamespaceConfig {
default_threshold: Some(0.99),
..Default::default()
},
);
assert_eq!(engine.effective_threshold("b"), 0.99);
}
}