// converge_analytics/model.rs
// Copyright 2024-2026 Reflective Labs

3use crate::engine::FeatureVector;
4use burn::{
5 nn::{Linear, LinearConfig, Relu},
6 prelude::*,
7 tensor::{Tensor, backend::Backend},
8};
9use converge_pack::{AgentEffect, Context, ContextKey, ProposedFact, Suggestor};
10
// NOTE: if `FeatureVector` is not public in `engine`, this definition should
// move to a shared module (lib or common). For this example we assume we can
// deserialize proposal content directly into that struct.
13
/// Simple MLP model: two linear layers with a ReLU activation in between.
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    fc1: Linear<B>,   // input -> hidden projection
    fc2: Linear<B>,   // hidden -> output projection
    activation: Relu, // applied between fc1 and fc2
}
21
22impl<B: Backend> Model<B> {
23 pub fn new(device: &B::Device) -> Self {
24 // Initialize with default config for demo
25 let config = ModelConfig::new(3, 16, 1);
26 config.init(device)
27 }
28
29 pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
30 let x = self.fc1.forward(input);
31 let x = self.activation.forward(x);
32 self.fc2.forward(x)
33 }
34}
35
/// Hyperparameters for [`Model`]: sizes of the input, hidden, and output layers.
#[derive(Config, Debug)]
pub struct ModelConfig {
    input_size: usize,  // features per sample (columns of the input matrix)
    hidden_size: usize, // width of the single hidden layer
    output_size: usize, // predictions emitted per sample
}
42
43impl ModelConfig {
44 pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
45 Model {
46 fc1: LinearConfig::new(self.input_size, self.hidden_size).init(device),
47 fc2: LinearConfig::new(self.hidden_size, self.output_size).init(device),
48 activation: Relu::new(),
49 }
50 }
51}
52
/// Suggestor that turns promoted feature facts into prediction hypotheses.
///
/// Deliberately stateless: holding a `Model<B>` here would force a backend
/// generic (or trait object) through the `Suggestor` trait. For this demo the
/// model is instantiated inside `execute` on a concrete backend instead.
/// (Burn models are cheap to clone when weights are `Arc`-backed, so a real
/// app could hold a loaded model here.)
#[derive(Debug, Default)]
pub struct InferenceAgent {}
61
62impl InferenceAgent {
63 pub fn new() -> Self {
64 Self {}
65 }
66}
67
68#[async_trait::async_trait]
69impl Suggestor for InferenceAgent {
70 fn name(&self) -> &str {
71 "InferenceAgent (Burn)"
72 }
73
74 fn dependencies(&self) -> &[ContextKey] {
75 &[ContextKey::Proposals]
76 }
77
78 fn accepts(&self, ctx: &dyn Context) -> bool {
79 // Run if there are proposals (features) but no hypothesis yet
80 ctx.has(ContextKey::Proposals) && !ctx.has(ContextKey::Hypotheses)
81 }
82
83 async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
84 // 1. Find the feature proposal
85 // In reality, filtered by provenance "polars-engine"
86 let _proposals = ctx.get(ContextKey::Proposals); // wait, ctx.get returns Fact, but proposals are ProposedFacts?
87 // Ah, ctx.get(ContextKey) returns FACTs (promoted).
88 // If FeatureAgent emits PROPOSALS, they are in `ContextKey::Proposals`?
89 // Wait, ContextKey::Proposals is a key where Validated Proposals might live?
90 // OR does FeatureAgent emit *Facts* directly if trusted?
91
92 // In the `engine.rs` implementation I sent `ProposedFact` with key `ContextKey::Proposals`.
93 // If they are not promoted to Facts, they are not in `ctx.get()`.
94 // `Context` only stores `facts`.
95 // Proposals usually sit in a queue in the Engine or are added to Context if Key::Proposals is a storage for them?
96 // Looking at `ContextKey` definition: "Internal storage for proposed facts before validation."
97 // So they ARE stored as FACTS under the key `Proposals` if the system works that way?
98 // OR `ProposedFact`s are converted to `Fact`s by the engine.
99 // `ProposedFact::try_from` converts to `Fact`.
100 // If the engine accepts the proposal, it adds it as a Fact.
101
102 // Let's assume the engine validated it and stored it.
103 // So we look for Facts in `ContextKey::Proposals`?
104 // Actually, normally `Proposals` key is for... proposals.
105 // But `FeatureAgent` intended to propose `context.key = Proposals`?
106 // No, `FeatureAgent` sent `proposal.key = Proposals`.
107
108 // Let's assume we find the features in `ContextKey::Proposals` (as stored Facts).
109
110 // We iterate and find one we haven't processed? For now just take the first.
111
112 // This logic is simplified for demo.
113
114 let facts = ctx.get(ContextKey::Proposals);
115 if facts.is_empty() {
116 return AgentEffect::empty();
117 }
118
119 let fact_content = &facts[0].content;
120
121 // 2. Deserialize features
122 let features: FeatureVector = match serde_json::from_str(fact_content) {
123 Ok(f) => f,
124 Err(_) => return AgentEffect::empty(),
125 };
126
127 // 3. Run Inference (Burn)
128 type B = burn::backend::NdArray;
129 let device = Default::default();
130 let model: Model<B> = ModelConfig::new(3, 16, 1).init(&device);
131
132 let input = Tensor::<B, 1>::from_floats(features.data.as_slice(), &device)
133 .reshape([features.shape[0], features.shape[1]]);
134
135 let output = model.forward(input);
136
137 // 4. Emit Hypothesis
138 let values: Vec<f32> = output.into_data().to_vec::<f32>().unwrap_or_default();
139 let prediction = values[0]; // Assume single output
140
141 let hypo_content = format!("Prediction: {:.4} (based on {})", prediction, facts[0].id);
142
143 let hypothesis = ProposedFact::new(
144 ContextKey::Hypotheses,
145 format!("hypo-{}", facts[0].id),
146 hypo_content,
147 self.name(),
148 );
149
150 AgentEffect::with_proposal(hypothesis)
151 }
152}
153
154/// Run batch inference on a [`FeatureVector`] using a configured model.
155///
156/// Abstracts Burn internals: the caller provides a [`ModelConfig`] and
157/// a [`FeatureVector`] (shape [n, input_size]), and receives a `Vec<f32>`
158/// of per-sample predictions.
159///
160/// Uses the `NdArray` backend internally.
161pub fn run_batch_inference(
162 config: &ModelConfig,
163 features: &FeatureVector,
164) -> anyhow::Result<Vec<f32>> {
165 type B = burn::backend::NdArray;
166 let device = Default::default();
167 let model: Model<B> = config.init(&device);
168
169 let n = features.rows();
170 let input = Tensor::<B, 1>::from_floats(features.data.as_slice(), &device)
171 .reshape([n, config.input_size]);
172 let output = model.forward(input);
173 let values: Vec<f32> = output.into_data().to_vec::<f32>().unwrap_or_default();
174 Ok(values)
175}