Skip to main content

brainwires_reasoning/
lib.rs

1#![deny(missing_docs)]
2//! # Brainwires Reasoning
3//!
4//! Layer 3 — Intelligence. Provider-agnostic reasoning primitives for the
5//! Brainwires Agent Framework.
6//!
7//! This crate owns:
8//!
9//! - **`plan_parser`** — extract numbered task steps from LLM plan output.
10//! - **`output_parser`** — parse structured data (JSON, regex) from raw LLM
11//!   text.
12//! - **Local inference scorers** — provider-agnostic TIER 1/2 components for
13//!   routing, validation, complexity scoring, summarisation, retrieval
14//!   gating, relevance ranking, strategy selection, and entity enhancement.
15//!   All accept `Arc<dyn Provider>` and fall back to pattern-based logic
16//!   when the provider is unavailable.
17//!
18//! ## Scope note — `prompting` stays in `brainwires-knowledge`
19//!
20//! The original architectural plan (`sleepy-popping-falcon.md`) called for
21//! `prompting` to move here too. It didn't — `prompting` is tightly coupled
22//! to `bks_pks` inside `brainwires-knowledge` and a move would have pulled
23//! an entire knowledge-store dependency into this crate. The deviation is
24//! intentional; tests and consumers of prompting should continue to target
25//! `brainwires_knowledge::prompting`.
26//!
27//! ## Configuration
28//!
29//! Enable via `LocalInferenceConfig`:
30//! ```toml
31//! [local_llm]
32//! enabled = true
33//! use_for_routing = true
34//! use_for_validation = true
35//! use_for_complexity = true
36//! use_for_summarization = true
37//! ```
38
39// ── Parsers (moved from brainwires-core) ─────────────────────────────────
40/// Structured output parsers for LLM responses.
41pub mod output_parser;
42/// Plan text parser for extracting steps from LLM output.
43pub mod plan_parser;
44
45// Flat re-exports for convenience.
46pub use output_parser::{JsonListParser, JsonOutputParser, OutputParser, RegexOutputParser};
47pub use plan_parser::{ParsedStep, parse_plan_steps, steps_to_tasks};
48
49// ── Scorers (moved from brainwires-agents::reasoning) ────────────────────
50mod complexity;
51mod entity_enhancer;
52mod relevance_scorer;
53mod retrieval_classifier;
54mod router;
55/// Reasoning strategies: CoT, ReAct, Reflexion, Tree-of-Thoughts.
56pub mod strategies;
57mod strategy_selector;
58mod summarizer;
59mod validator;
60
61pub use complexity::{ComplexityResult, ComplexityScorer, ComplexityScorerBuilder};
62pub use entity_enhancer::{
63    EnhancedEntity, EnhancedRelationship, EnhancementResult, EntityEnhancer, EntityEnhancerBuilder,
64    RelationType, SemanticEntityType,
65};
66pub use relevance_scorer::{RelevanceResult, RelevanceScorer, RelevanceScorerBuilder};
67pub use retrieval_classifier::{
68    ClassificationResult, RetrievalClassifier, RetrievalClassifierBuilder,
69    RetrievalNeed as LocalRetrievalNeed,
70};
71pub use router::{LocalRouter, LocalRouterBuilder, RouteResult};
72pub use strategies::{
73    ChainOfThoughtStrategy, ReActStrategy, ReasoningStrategy, ReflexionStrategy, StrategyPreset,
74    StrategyStep, TreeOfThoughtsStrategy,
75};
76pub use strategy_selector::{
77    RecommendedStrategy, StrategyResult, StrategySelector, StrategySelectorBuilder, TaskType,
78};
79pub use summarizer::{
80    ExtractedFact, FactCategory, LocalSummarizer, LocalSummarizerBuilder, SummarizationResult,
81};
82pub use validator::{LocalValidator, LocalValidatorBuilder, ValidationResult};
83
84use std::time::Instant;
85use tracing::{info, warn};
86
/// Feature switches and per-task model choices for the local inference components.
///
/// Each capability has its own `*_enabled` flag and an optional model ID in
/// the corresponding `*_model` field. The [`Default`] value leaves every
/// capability switched off while still pre-selecting a model for each task.
#[derive(Clone, Debug)]
pub struct LocalInferenceConfig {
    // ── TIER 1: quick wins ──
    /// Turns local routing on or off.
    pub routing_enabled: bool,
    /// Turns local validation on or off.
    pub validation_enabled: bool,
    /// Turns complexity scoring on or off.
    pub complexity_enabled: bool,

    // ── TIER 2: context & retrieval ──
    /// Turns local summarization (used for tiered memory) on or off.
    pub summarization_enabled: bool,
    /// Turns local retrieval gating on or off.
    pub retrieval_gating_enabled: bool,
    /// Turns local relevance scoring on or off.
    pub relevance_scoring_enabled: bool,
    /// Turns local strategy selection on or off.
    pub strategy_selection_enabled: bool,
    /// Turns local entity enhancement on or off.
    pub entity_enhancement_enabled: bool,

    // ── Per-task model selection ──
    /// Model ID for routing; a fast model is preferred.
    pub routing_model: Option<String>,
    /// Model ID for validation; a fast model is preferred.
    pub validation_model: Option<String>,
    /// Model ID for complexity scoring; a fast model is preferred.
    pub complexity_model: Option<String>,
    /// Model ID for summarization; a larger model is preferred.
    pub summarization_model: Option<String>,
    /// Model ID for retrieval classification; a fast model is preferred.
    pub retrieval_model: Option<String>,
    /// Model ID for relevance scoring; a fast model is preferred.
    pub relevance_model: Option<String>,
    /// Model ID for strategy selection; a larger model is preferred.
    pub strategy_model: Option<String>,
    /// Model ID for entity enhancement; a fast model is preferred.
    pub entity_model: Option<String>,

    /// Whether all local inference calls should be logged.
    pub log_inference: bool,
}

impl Default for LocalInferenceConfig {
    /// All features disabled, a model pre-selected for every task, logging on.
    fn default() -> Self {
        // Only two model IDs are used by default: a small one for the
        // "fast model preferred" tasks and a larger one for the rest.
        let fast = || Some(String::from("lfm2-350m"));
        let larger = || Some(String::from("lfm2-1.2b"));

        Self {
            // Feature flags: everything is opt-in.
            routing_enabled: false,
            validation_enabled: false,
            complexity_enabled: false,
            summarization_enabled: false,
            retrieval_gating_enabled: false,
            relevance_scoring_enabled: false,
            strategy_selection_enabled: false,
            entity_enhancement_enabled: false,

            // TIER 1 models.
            routing_model: fast(),
            validation_model: fast(),
            complexity_model: fast(),

            // TIER 2 models.
            summarization_model: larger(),
            retrieval_model: fast(),
            relevance_model: fast(),
            strategy_model: larger(),
            entity_model: fast(),

            log_inference: true,
        }
    }
}

impl LocalInferenceConfig {
    /// Preset with the three TIER 1 features (routing, validation, complexity) on.
    pub fn tier1_enabled() -> Self {
        Self {
            complexity_enabled: true,
            validation_enabled: true,
            routing_enabled: true,
            ..Self::default()
        }
    }

    /// Preset with the five TIER 2 features on and TIER 1 left at its default.
    pub fn tier2_enabled() -> Self {
        Self {
            entity_enhancement_enabled: true,
            strategy_selection_enabled: true,
            relevance_scoring_enabled: true,
            retrieval_gating_enabled: true,
            summarization_enabled: true,
            ..Self::default()
        }
    }

    /// Preset with every feature on.
    pub fn all_enabled() -> Self {
        // Start from the TIER 2 preset and add the TIER 1 flags on top.
        Self {
            routing_enabled: true,
            validation_enabled: true,
            complexity_enabled: true,
            ..Self::tier2_enabled()
        }
    }

    /// Preset with routing as the only enabled feature.
    pub fn routing_only() -> Self {
        Self {
            routing_enabled: true,
            ..Self::default()
        }
    }

    /// Preset with validation as the only enabled feature.
    pub fn validation_only() -> Self {
        Self {
            validation_enabled: true,
            ..Self::default()
        }
    }

    /// Preset with summarization as the only enabled feature.
    pub fn summarization_only() -> Self {
        Self {
            summarization_enabled: true,
            ..Self::default()
        }
    }
}
222
223/// Log a local inference event.
224pub fn log_inference(task: &str, model: &str, latency_ms: u64, success: bool) {
225    if success {
226        info!(
227            target: "local_llm",
228            task = task,
229            model = model,
230            latency_ms = latency_ms,
231            "Local inference completed"
232        );
233    } else {
234        warn!(
235            target: "local_llm",
236            task = task,
237            model = model,
238            latency_ms = latency_ms,
239            "Local inference failed, falling back to pattern-based"
240        );
241    }
242}
243
244/// Measure inference latency.
245pub struct InferenceTimer {
246    start: Instant,
247    task: String,
248    model: String,
249}
250
251impl InferenceTimer {
252    /// Create a new inference timer for the given task and model.
253    pub fn new(task: impl Into<String>, model: impl Into<String>) -> Self {
254        Self {
255            start: Instant::now(),
256            task: task.into(),
257            model: model.into(),
258        }
259    }
260
261    /// Stop the timer and log the inference event.
262    pub fn finish(self, success: bool) {
263        let latency_ms = self.start.elapsed().as_millis() as u64;
264        log_inference(&self.task, &self.model, latency_ms, success);
265    }
266
267    /// Return the elapsed time in milliseconds since the timer was created.
268    pub fn elapsed_ms(&self) -> u64 {
269        self.start.elapsed().as_millis() as u64
270    }
271}
272
// Unit tests for the configuration presets and the inference timer.
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config must leave every feature off, including the two
    /// TIER 2 flags that were previously unchecked.
    #[test]
    fn test_config_default() {
        let config = LocalInferenceConfig::default();
        assert!(!config.routing_enabled);
        assert!(!config.validation_enabled);
        assert!(!config.complexity_enabled);
        assert!(!config.summarization_enabled);
        assert!(!config.retrieval_gating_enabled);
        assert!(!config.relevance_scoring_enabled);
        assert!(!config.strategy_selection_enabled);
        assert!(!config.entity_enhancement_enabled);
        assert!(config.log_inference);
        assert_eq!(config.routing_model.as_deref(), Some("lfm2-350m"));
        assert_eq!(config.strategy_model.as_deref(), Some("lfm2-1.2b"));
    }

    /// TIER 1 preset: all three TIER 1 flags on, every TIER 2 flag off.
    #[test]
    fn test_config_tier1_enabled() {
        let config = LocalInferenceConfig::tier1_enabled();
        assert!(config.routing_enabled);
        assert!(config.validation_enabled);
        assert!(config.complexity_enabled);
        assert!(!config.summarization_enabled);
        assert!(!config.retrieval_gating_enabled);
        assert!(!config.relevance_scoring_enabled);
        assert!(!config.strategy_selection_enabled);
        assert!(!config.entity_enhancement_enabled);
    }

    /// TIER 2 preset: all five TIER 2 flags on, every TIER 1 flag off.
    #[test]
    fn test_config_tier2_enabled() {
        let config = LocalInferenceConfig::tier2_enabled();
        assert!(!config.routing_enabled);
        assert!(!config.validation_enabled);
        assert!(!config.complexity_enabled);
        assert!(config.summarization_enabled);
        assert!(config.retrieval_gating_enabled);
        assert!(config.relevance_scoring_enabled);
        assert!(config.strategy_selection_enabled);
        assert!(config.entity_enhancement_enabled);
    }

    /// The all-features preset turns on every flag.
    #[test]
    fn test_config_all_enabled() {
        let config = LocalInferenceConfig::all_enabled();
        assert!(config.routing_enabled);
        assert!(config.validation_enabled);
        assert!(config.complexity_enabled);
        assert!(config.summarization_enabled);
        assert!(config.retrieval_gating_enabled);
        assert!(config.relevance_scoring_enabled);
        assert!(config.strategy_selection_enabled);
        assert!(config.entity_enhancement_enabled);
    }

    /// Summarization-only preset keeps the default summarization model.
    #[test]
    fn test_config_summarization_only() {
        let config = LocalInferenceConfig::summarization_only();
        assert!(!config.routing_enabled);
        assert!(config.summarization_enabled);
        assert_eq!(config.summarization_model, Some("lfm2-1.2b".to_string()));
    }

    /// Routing-only and validation-only presets enable exactly one flag each.
    #[test]
    fn test_config_single_feature_presets() {
        let routing = LocalInferenceConfig::routing_only();
        assert!(routing.routing_enabled);
        assert!(!routing.validation_enabled);
        assert!(!routing.complexity_enabled);
        assert!(!routing.summarization_enabled);

        let validation = LocalInferenceConfig::validation_only();
        assert!(validation.validation_enabled);
        assert!(!validation.routing_enabled);
        assert!(!validation.complexity_enabled);
        assert!(!validation.summarization_enabled);
    }

    /// The timer reports at least the time spent sleeping.
    #[test]
    fn test_inference_timer() {
        let timer = InferenceTimer::new("test_task", "test_model");
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(timer.elapsed_ms() >= 10);
    }
}