use std::sync::Arc;
use fuzzy_parser::distance::{find_closest, Algorithm};
use fuzzy_parser::{repair_object_fields, sanitize_json, ObjectSchema};
pub use swarm_engine_core::exploration::{
SelectionKind, StrategyAdvice, StrategyAdviceError, StrategyAdvisor, StrategyContext,
};
use crate::decider::{LlmDecider, LlmError, LoraConfig};
use crate::json_prompt::strategy_selection_template;
/// Parses a strategy name into a [`SelectionKind`], tolerating case differences,
/// surrounding whitespace, and small typos.
///
/// Matching happens in two stages:
/// 1. An exact, case-insensitive comparison against `FIFO`, `UCB1`, `Greedy`, `Thompson`.
/// 2. Failing that, `find_closest` with the Jaro-Winkler algorithm against the same
///    names, using `0.6` as the similarity cutoff passed to `find_closest`.
///
/// Returns `None` when neither stage produces a match.
pub fn parse_selection_kind_fuzzy(s: &str) -> Option<SelectionKind> {
    /// Maps an upper-cased canonical name to its `SelectionKind`.
    fn kind_for(upper: &str) -> Option<SelectionKind> {
        match upper {
            "FIFO" => Some(SelectionKind::Fifo),
            "UCB1" => Some(SelectionKind::Ucb1),
            "GREEDY" => Some(SelectionKind::Greedy),
            "THOMPSON" => Some(SelectionKind::Thompson),
            _ => None,
        }
    }

    // LLM output often carries stray spaces/newlines around the value; strip them so
    // the exact stage handles such input instead of relying on the fuzzy cutoff.
    let s = s.trim();

    if let Some(kind) = kind_for(&s.to_uppercase()) {
        return Some(kind);
    }

    let candidates = ["FIFO", "UCB1", "Greedy", "Thompson"];
    let m = find_closest(s, candidates, 0.6, Algorithm::JaroWinkler)?;
    // Candidates are mixed-case; normalise so the single lookup table above applies.
    kind_for(&m.candidate.to_uppercase())
}
impl From<LlmError> for StrategyAdviceError {
fn from(e: LlmError) -> Self {
Self::LlmError(e.message().to_string())
}
}
/// Builds the LLM prompt that asks for a selection-strategy recommendation.
///
/// Stateless: the prompt is derived entirely from the supplied [`StrategyContext`]
/// and then wrapped by `strategy_selection_template()`.
#[derive(Debug, Clone, Default)]
pub struct StrategyPromptBuilder;

impl StrategyPromptBuilder {
    /// Creates a builder (identical to `StrategyPromptBuilder::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Renders the full prompt for `ctx`.
    ///
    /// The user section lists the available strategies, the selection guidelines,
    /// and the current statistics: frontier size, visit count, failure rate as a
    /// whole-number percentage, the average depth to one decimal (only when
    /// `ctx.avg_depth` is set), and the current strategy.
    pub fn build(&self, ctx: &StrategyContext) -> String {
        let failure_pct = ctx.failure_rate * 100.0;
        let depth_suffix = match ctx.avg_depth {
            Some(depth) => format!(", depth={:.1}", depth),
            None => String::new(),
        };

        // The trailing `\` continues the literal and drops the next line's indentation.
        let user_section = format!(
            "Strategies: FIFO, UCB1, Greedy, Thompson\n\
             Guidelines: visits<20→UCB1, failure>30%→Thompson, established+low failure→Greedy\n\
             User: frontier={}, visits={}, failure={:.0}%{}, current={}",
            ctx.frontier_count,
            ctx.total_visits,
            failure_pct,
            depth_suffix,
            ctx.current_strategy,
        );

        strategy_selection_template().build(&user_section)
    }
}
/// Expected field names of the strategy-advice JSON object.
///
/// `StrategyResponseParser::parse_json` passes this to `repair_object_fields` so
/// misspelled keys emitted by the LLM (e.g. `straegy`) are corrected before the
/// fields are read.
const STRATEGY_FIELDS: ObjectSchema =
ObjectSchema::new(&["strategy", "change", "confidence", "reason"]);
/// Turns a raw LLM reply into a [`StrategyAdvice`].
///
/// The parser is deliberately forgiving. In order of preference it accepts:
/// 1. the contents of a fenced `json` code block,
/// 2. the first balanced `{...}` object embedded anywhere in the text,
/// 3. a strategy name mentioned in plain natural language (see
///    `extract_from_natural_language`).
///
/// The extracted text is passed through `sanitize_json` before being decoded.
#[derive(Debug, Clone, Default)]
pub struct StrategyResponseParser;

impl StrategyResponseParser {
    /// Creates a parser (identical to `StrategyResponseParser::default()`).
    pub fn new() -> Self {
        Self
    }

    /// Parses `response` into advice.
    ///
    /// # Errors
    /// Returns [`StrategyAdviceError::ParseError`] when no JSON or recognisable
    /// strategy name is found, the (sanitised) JSON cannot be decoded, the
    /// `strategy` field is missing, or its value matches no known strategy.
    pub fn parse(&self, response: &str) -> Result<StrategyAdvice, StrategyAdviceError> {
        let json_str = self.extract_json(response)?;
        let sanitized = sanitize_json(&json_str);
        tracing::debug!(sanitized = %sanitized, "Sanitized strategy JSON");
        self.parse_json(&sanitized)
    }

    /// Locates the JSON payload inside `text` using the fallbacks listed on the type.
    fn extract_json(&self, text: &str) -> Result<String, StrategyAdviceError> {
        const FENCE_OPEN: &str = "```json";
        const FENCE_CLOSE: &str = "```";

        if let Some(start) = text.find(FENCE_OPEN) {
            let remaining = &text[start + FENCE_OPEN.len()..];
            if let Some(end) = remaining.find(FENCE_CLOSE) {
                let json = remaining[..end].trim();
                // An empty fenced block falls through to the other strategies.
                if !json.is_empty() {
                    return Ok(json.to_string());
                }
            }
        }

        if let Some(json) = self.extract_balanced_json(text) {
            return Ok(json);
        }

        if let Some(json) = self.extract_from_natural_language(text) {
            tracing::debug!(fallback_json = %json, "Extracted strategy from natural language");
            return Ok(json);
        }

        Err(StrategyAdviceError::ParseError(format!(
            "No JSON found in response: {}",
            text
        )))
    }

    /// Synthesises a JSON object from a free-text reply that names a strategy.
    ///
    /// First, for each recommendation phrase (`RECOMMEND`, `SUGGEST`, `USE `,
    /// `PREFER`, `OPTIMAL`, `BEST`, in that order), the first occurrence is found and
    /// the following ~50 bytes are scanned. The strategy mentioned earliest in that
    /// window is returned with `change: true` and `confidence: 0.6`. If no window
    /// names a strategy, the strategy mentioned earliest in the whole text is returned
    /// with `change: false` and `confidence: 0.5`. Returns `None` if no strategy is
    /// named at all. Matching is case-insensitive.
    fn extract_from_natural_language(&self, text: &str) -> Option<String> {
        const RECOMMEND_PATTERNS: [&str; 6] =
            ["RECOMMEND", "SUGGEST", "USE ", "PREFER", "OPTIMAL", "BEST"];
        // How far (in bytes) past a recommendation phrase to look for a strategy name.
        const WINDOW_BYTES: usize = 50;

        let text_upper = text.to_uppercase();

        for pattern in RECOMMEND_PATTERNS.iter() {
            if let Some(pos) = text_upper.find(pattern) {
                // Clamp to a char boundary: a raw `pos + 50` byte offset can land inside
                // a multi-byte character (e.g. `→`) and slicing there would panic.
                let end = Self::floor_char_boundary(&text_upper, pos + WINDOW_BYTES);
                if let Some(strategy) = Self::earliest_strategy(&text_upper[pos..end]) {
                    return Some(format!(
                        r#"{{"strategy":"{}","change":true,"confidence":0.6,"reason":"Extracted from natural language response"}}"#,
                        strategy
                    ));
                }
            }
        }

        Self::earliest_strategy(&text_upper).map(|strategy| {
            format!(
                r#"{{"strategy":"{}","change":false,"confidence":0.5,"reason":"Extracted from natural language response"}}"#,
                strategy
            )
        })
    }

    /// Returns the canonical name of the strategy whose keyword occurs earliest in
    /// `haystack`, which must already be upper-cased. On ties the table order wins,
    /// so `UCB1`/`UCB` at the same position both yield `"UCB1"`.
    fn earliest_strategy(haystack: &str) -> Option<&'static str> {
        const STRATEGIES: [(&str, &str); 5] = [
            ("THOMPSON", "Thompson"),
            ("UCB1", "UCB1"),
            ("UCB", "UCB1"),
            ("GREEDY", "Greedy"),
            ("FIFO", "FIFO"),
        ];
        STRATEGIES
            .iter()
            .filter_map(|&(keyword, strategy)| haystack.find(keyword).map(|pos| (pos, strategy)))
            // `min_by_key` keeps the first of equal minima, preserving table order.
            .min_by_key(|&(pos, _)| pos)
            .map(|(_, strategy)| strategy)
    }

    /// Largest index `<= index` (capped at `s.len()`) that lies on a char boundary of `s`.
    fn floor_char_boundary(s: &str, index: usize) -> usize {
        let mut end = index.min(s.len());
        // Index 0 is always a boundary, so this terminates.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        end
    }

    /// Returns the first brace-balanced `{...}` substring of `text`, if any.
    ///
    /// Braces inside JSON string literals (including after `\"` escapes) are ignored.
    /// Returns `None` when no `{` exists or the object is never closed.
    fn extract_balanced_json(&self, text: &str) -> Option<String> {
        let start = text.find('{')?;
        let candidate = &text[start..];

        // The first char is always `{`, so depth is >= 1 before any `}` is counted
        // and the early return at 0 keeps it from ever underflowing.
        let mut depth: usize = 0;
        let mut in_string = false;
        let mut escape_next = false;

        for (i, ch) in candidate.char_indices() {
            if escape_next {
                escape_next = false;
                continue;
            }
            match ch {
                '\\' if in_string => escape_next = true,
                '"' => in_string = !in_string,
                '{' if !in_string => depth += 1,
                '}' if !in_string => {
                    depth -= 1;
                    if depth == 0 {
                        return Some(candidate[..i + ch.len_utf8()].to_string());
                    }
                }
                _ => {}
            }
        }
        None
    }

    /// Decodes `json` and maps its fields onto a [`StrategyAdvice`].
    ///
    /// Field names are first fuzzy-repaired against `STRATEGY_FIELDS`. Defaults:
    /// a missing or non-boolean `change` becomes `false`; a missing or non-numeric
    /// `confidence` becomes `0.5`, and any confidence is clamped to `[0.0, 1.0]`;
    /// a missing `reason` becomes `"No reason provided"`.
    fn parse_json(&self, json: &str) -> Result<StrategyAdvice, StrategyAdviceError> {
        let mut parsed: serde_json::Value = serde_json::from_str(json)
            .map_err(|e| StrategyAdviceError::ParseError(format!("JSON parse error: {}", e)))?;

        if let Some(obj) = parsed.as_object_mut() {
            let corrections = repair_object_fields(obj, &STRATEGY_FIELDS, "$", &Default::default());
            if !corrections.is_empty() {
                tracing::debug!(
                    corrections = ?corrections.iter().map(|c| format!("{} → {}", c.original, c.corrected)).collect::<Vec<_>>(),
                    "Fuzzy repaired strategy field names"
                );
            }
        }

        let strategy_str = parsed["strategy"]
            .as_str()
            .ok_or_else(|| StrategyAdviceError::ParseError("Missing 'strategy' field".into()))?;
        let recommended = parse_selection_kind_fuzzy(strategy_str).ok_or_else(|| {
            StrategyAdviceError::ParseError(format!("Unknown strategy: {}", strategy_str))
        })?;
        let should_change = parsed["change"].as_bool().unwrap_or(false);
        let confidence = parsed["confidence"].as_f64().unwrap_or(0.5).clamp(0.0, 1.0);
        let reason = parsed["reason"]
            .as_str()
            .unwrap_or("No reason provided")
            .to_string();

        Ok(StrategyAdvice {
            recommended,
            should_change,
            reason,
            confidence,
        })
    }
}
/// [`StrategyAdvisor`] that asks an LLM, through an [`LlmDecider`], which
/// selection strategy to use.
///
/// Advice whose confidence falls below `confidence_threshold` (default `0.6`) is
/// downgraded to "do not change" by the `advise` implementation.
pub struct LlmStrategyAdvisor {
    // Backend that turns the prompt into raw response text.
    decider: Arc<dyn LlmDecider>,
    // Runtime handle used to drive the async LLM call from the synchronous `advise`.
    runtime: tokio::runtime::Handle,
    prompt_builder: StrategyPromptBuilder,
    response_parser: StrategyResponseParser,
    // Minimum confidence for a recommendation to be acted on.
    confidence_threshold: f64,
    // Optional LoRA adapter forwarded with every LLM call.
    lora: Option<LoraConfig>,
}

impl LlmStrategyAdvisor {
    /// Creates an advisor with a confidence threshold of `0.6` and no LoRA adapter.
    pub fn new(decider: Arc<dyn LlmDecider>, runtime: tokio::runtime::Handle) -> Self {
        let prompt_builder = StrategyPromptBuilder::new();
        let response_parser = StrategyResponseParser::new();
        Self {
            decider,
            runtime,
            prompt_builder,
            response_parser,
            confidence_threshold: 0.6,
            lora: None,
        }
    }

    /// Sets the minimum confidence needed to act on advice, clamped to `[0.0, 1.0]`.
    ///
    /// Note: `f64::clamp` returns NaN for a NaN input, and a NaN threshold makes the
    /// low-confidence comparison in `advise` always false.
    pub fn with_confidence_threshold(self, threshold: f64) -> Self {
        Self {
            confidence_threshold: threshold.clamp(0.0, 1.0),
            ..self
        }
    }

    /// The current confidence threshold.
    pub fn confidence_threshold(&self) -> f64 {
        self.confidence_threshold
    }

    /// Sets the LoRA adapter forwarded to every `call_raw` request.
    pub fn with_lora(self, lora: LoraConfig) -> Self {
        Self {
            lora: Some(lora),
            ..self
        }
    }

    /// The configured LoRA adapter, if any.
    pub fn lora(&self) -> Option<&LoraConfig> {
        self.lora.as_ref()
    }
}
impl StrategyAdvisor for LlmStrategyAdvisor {
    /// Asks the LLM for a strategy recommendation for `context`.
    ///
    /// The reply is parsed with [`StrategyResponseParser`]. If the parsed confidence
    /// is below the configured threshold, the advice is still returned, but with
    /// `should_change` forced to `false` and its reason prefixed with a
    /// "Low confidence" note.
    ///
    /// Note: the async call is driven with `tokio::runtime::Handle::block_on`, which
    /// (per tokio's documentation) panics if invoked from within an asynchronous
    /// execution context, so call this from a synchronous thread.
    ///
    /// # Errors
    /// Propagates LLM call failures (converted from `LlmError`) and response parse
    /// failures from [`StrategyResponseParser::parse`].
    fn advise(&self, context: &StrategyContext) -> Result<StrategyAdvice, StrategyAdviceError> {
        let prompt = self.prompt_builder.build(context);
        tracing::debug!(prompt = %prompt, "Strategy advisor prompt");

        let response = self
            .runtime
            .block_on(async { self.decider.call_raw(&prompt, self.lora.as_ref()).await })?;
        tracing::debug!(response = %response, "Strategy advisor raw response");

        let mut advice = self.response_parser.parse(&response)?;
        let threshold = self.confidence_threshold;
        let below_threshold = advice.confidence < threshold;
        if below_threshold {
            tracing::debug!(
                confidence = advice.confidence,
                threshold = threshold,
                "Low confidence, not changing strategy"
            );
            advice.reason = format!(
                "Low confidence ({:.2} < {:.2}): {}",
                advice.confidence, threshold, advice.reason
            );
            advice.should_change = false;
        }

        tracing::info!(
            recommended = %advice.recommended,
            should_change = advice.should_change,
            confidence = advice.confidence,
            reason = %advice.reason,
            "Strategy advice"
        );
        Ok(advice)
    }

    /// Stable identifier for this advisor.
    fn name(&self) -> &str {
        "LlmStrategyAdvisor"
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for strategy-name parsing, prompt building, and response parsing.
    use super::*;

    #[test]
    fn test_selection_kind_display() {
        assert_eq!(SelectionKind::Fifo.to_string(), "FIFO");
        assert_eq!(SelectionKind::Ucb1.to_string(), "UCB1");
        assert_eq!(SelectionKind::Greedy.to_string(), "Greedy");
        assert_eq!(SelectionKind::Thompson.to_string(), "Thompson");
    }

    #[test]
    fn test_selection_kind_from_str_exact() {
        assert_eq!(
            parse_selection_kind_fuzzy("FIFO"),
            Some(SelectionKind::Fifo)
        );
        assert_eq!(
            parse_selection_kind_fuzzy("UCB1"),
            Some(SelectionKind::Ucb1)
        );
        assert_eq!(
            parse_selection_kind_fuzzy("Greedy"),
            Some(SelectionKind::Greedy)
        );
        assert_eq!(
            parse_selection_kind_fuzzy("Thompson"),
            Some(SelectionKind::Thompson)
        );
    }

    #[test]
    fn test_selection_kind_from_str_case_insensitive() {
        assert_eq!(
            parse_selection_kind_fuzzy("fifo"),
            Some(SelectionKind::Fifo)
        );
        assert_eq!(
            parse_selection_kind_fuzzy("ucb1"),
            Some(SelectionKind::Ucb1)
        );
        assert_eq!(
            parse_selection_kind_fuzzy("GREEDY"),
            Some(SelectionKind::Greedy)
        );
        assert_eq!(
            parse_selection_kind_fuzzy("THOMPSON"),
            Some(SelectionKind::Thompson)
        );
    }

    #[test]
    fn test_selection_kind_from_str_fuzzy() {
        assert_eq!(
            parse_selection_kind_fuzzy("Thomspon"),
            Some(SelectionKind::Thompson)
        );
        assert_eq!(
            parse_selection_kind_fuzzy("Gredy"),
            Some(SelectionKind::Greedy)
        );
    }

    #[test]
    fn test_selection_kind_from_str_invalid() {
        assert_eq!(parse_selection_kind_fuzzy("Unknown"), None);
        assert_eq!(parse_selection_kind_fuzzy("Random"), None);
    }

    #[test]
    fn test_strategy_context_new() {
        let ctx = StrategyContext::new(15, 47, 0.23, SelectionKind::Ucb1);
        assert_eq!(ctx.frontier_count, 15);
        assert_eq!(ctx.total_visits, 47);
        assert!((ctx.failure_rate - 0.23).abs() < 0.001);
        assert!((ctx.success_rate - 0.77).abs() < 0.001);
        assert_eq!(ctx.current_strategy, SelectionKind::Ucb1);
        assert!(ctx.avg_depth.is_none());
    }

    #[test]
    fn test_strategy_context_with_depth() {
        let ctx = StrategyContext::new(10, 100, 0.1, SelectionKind::Greedy).with_avg_depth(3.5);
        assert_eq!(ctx.avg_depth, Some(3.5));
    }

    #[test]
    fn test_strategy_advice_no_change() {
        let advice = StrategyAdvice::no_change(SelectionKind::Ucb1, "Exploration phase");
        assert_eq!(advice.recommended, SelectionKind::Ucb1);
        assert!(!advice.should_change);
        assert_eq!(advice.reason, "Exploration phase");
        assert!((advice.confidence - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_strategy_advice_change_to() {
        let advice = StrategyAdvice::change_to(SelectionKind::Greedy, "Patterns established", 0.85);
        assert_eq!(advice.recommended, SelectionKind::Greedy);
        assert!(advice.should_change);
        assert_eq!(advice.reason, "Patterns established");
        assert!((advice.confidence - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_prompt_builder_basic() {
        let builder = StrategyPromptBuilder::new();
        let ctx = StrategyContext::new(15, 47, 0.23, SelectionKind::Ucb1);
        let prompt = builder.build(&ctx);
        assert!(prompt.contains("Example interaction:"));
        assert!(prompt.contains("Your JSON:"));
        assert!(prompt.contains("frontier=15"));
        assert!(prompt.contains("visits=47"));
        assert!(prompt.contains("failure=23%"));
        assert!(prompt.contains("current=UCB1"));
        assert!(prompt.contains("FIFO"));
        assert!(prompt.contains("Greedy"));
        assert!(prompt.contains("Thompson"));
    }

    #[test]
    fn test_prompt_builder_with_depth() {
        let builder = StrategyPromptBuilder::new();
        let ctx = StrategyContext::new(10, 100, 0.1, SelectionKind::Greedy).with_avg_depth(3.5);
        let prompt = builder.build(&ctx);
        assert!(prompt.contains("depth=3.5"));
    }

    #[test]
    fn test_parse_valid_json() {
        let parser = StrategyResponseParser::new();
        let response = r#"{"strategy": "Greedy", "change": true, "confidence": 0.85, "reason": "Low failure rate"}"#;
        let advice = parser.parse(response).unwrap();
        assert_eq!(advice.recommended, SelectionKind::Greedy);
        assert!(advice.should_change);
        assert!((advice.confidence - 0.85).abs() < 0.001);
        assert_eq!(advice.reason, "Low failure rate");
    }

    #[test]
    fn test_parse_json_with_prefix() {
        let parser = StrategyResponseParser::new();
        let response = r#"Based on the analysis: {"strategy": "Thompson", "change": true, "confidence": 0.7, "reason": "High variance"}"#;
        let advice = parser.parse(response).unwrap();
        assert_eq!(advice.recommended, SelectionKind::Thompson);
    }

    #[test]
    fn test_parse_json_markdown_block() {
        let parser = StrategyResponseParser::new();
        let response = r#"```json
{"strategy": "UCB1", "change": false, "confidence": 0.9, "reason": "Still exploring"}
```"#;
        let advice = parser.parse(response).unwrap();
        assert_eq!(advice.recommended, SelectionKind::Ucb1);
        assert!(!advice.should_change);
    }

    #[test]
    fn test_parse_json_typo_repair() {
        let parser = StrategyResponseParser::new();
        let response =
            r#"{"straegy": "Greedy", "change": true, "confidnce": 0.8, "reason": "test"}"#;
        let advice = parser.parse(response).unwrap();
        assert_eq!(advice.recommended, SelectionKind::Greedy);
    }

    #[test]
    fn test_parse_json_strategy_typo() {
        let parser = StrategyResponseParser::new();
        let response =
            r#"{"strategy": "Thomspon", "change": true, "confidence": 0.75, "reason": "variance"}"#;
        let advice = parser.parse(response).unwrap();
        assert_eq!(advice.recommended, SelectionKind::Thompson);
    }

    #[test]
    fn test_parse_json_defaults() {
        let parser = StrategyResponseParser::new();
        let response = r#"{"strategy": "FIFO", "reason": "simple"}"#;
        let advice = parser.parse(response).unwrap();
        assert_eq!(advice.recommended, SelectionKind::Fifo);
        assert!(!advice.should_change);
        assert!((advice.confidence - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_parse_json_missing_strategy() {
        let parser = StrategyResponseParser::new();
        let response = r#"{"change": true, "confidence": 0.8}"#;
        let result = parser.parse(response);
        assert!(result.is_err());
        assert!(matches!(result, Err(StrategyAdviceError::ParseError(_))));
    }

    #[test]
    fn test_parse_no_json() {
        let parser = StrategyResponseParser::new();
        let response = "This is just plain text without any JSON.";
        let result = parser.parse(response);
        assert!(result.is_err());
        assert!(matches!(result, Err(StrategyAdviceError::ParseError(_))));
    }

    #[test]
    fn test_parse_confidence_clamping() {
        let parser = StrategyResponseParser::new();
        let response =
            r#"{"strategy": "Greedy", "change": true, "confidence": 1.5, "reason": "test"}"#;
        let advice = parser.parse(response).unwrap();
        assert!((advice.confidence - 1.0).abs() < 0.001);
        let response =
            r#"{"strategy": "Greedy", "change": true, "confidence": -0.5, "reason": "test"}"#;
        let advice = parser.parse(response).unwrap();
        assert!((advice.confidence - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_parse_natural_language_recommendation() {
        // A recommendation phrase followed by a strategy name → change with confidence 0.6.
        let parser = StrategyResponseParser::new();
        let advice = parser
            .parse("I recommend switching to Thompson sampling.")
            .unwrap();
        assert_eq!(advice.recommended, SelectionKind::Thompson);
        assert!(advice.should_change);
        assert!((advice.confidence - 0.6).abs() < 0.001);
    }

    #[test]
    fn test_parse_natural_language_mention_only() {
        // A bare mention without a recommendation phrase → no change, confidence 0.5.
        let parser = StrategyResponseParser::new();
        let advice = parser.parse("Current approach Greedy seems fine.").unwrap();
        assert_eq!(advice.recommended, SelectionKind::Greedy);
        assert!(!advice.should_change);
        assert!((advice.confidence - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_parse_balanced_json_with_braces_in_string() {
        // Braces inside a JSON string must not end the balanced-object scan early.
        let parser = StrategyResponseParser::new();
        let response = r#"Answer: {"strategy": "UCB1", "change": false, "confidence": 0.9, "reason": "keep {exploring}"} done"#;
        let advice = parser.parse(response).unwrap();
        assert_eq!(advice.recommended, SelectionKind::Ucb1);
    }
}