Skip to main content

prism/
model.rs

1// Copyright 2024-2026 Reflective Labs
2
3use crate::engine::FeatureVector;
4use crate::provenance::PRISM_PROVENANCE;
5use burn::{
6    nn::{Linear, LinearConfig, Relu},
7    prelude::*,
8    tensor::{Tensor, backend::Backend},
9};
10use converge_pack::{AgentEffect, Context, ContextKey, ProvenanceSource, Suggestor, TextPayload};
11
// `FeatureVector` is imported from `crate::engine`; strictly it should live in a
// shared (lib/common) module. This example assumes feature payloads decode into it.
14
/// Simple two-layer MLP: `fc1` -> ReLU -> `fc2`.
///
/// Build it with [`ModelConfig::init`], or with [`Model::new`] for the fixed
/// 3 -> 16 -> 1 demo shape.
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    /// First linear layer (`input_size` -> `hidden_size`).
    fc1: Linear<B>,
    /// Output linear layer (`hidden_size` -> `output_size`).
    fc2: Linear<B>,
    /// Non-linearity applied between `fc1` and `fc2`.
    activation: Relu,
}
22
23impl<B: Backend> Model<B> {
24    pub fn new(device: &B::Device) -> Self {
25        // Initialize with default config for demo
26        let config = ModelConfig::new(3, 16, 1);
27        config.init(device)
28    }
29
30    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
31        let x = self.fc1.forward(input);
32        let x = self.activation.forward(x);
33        self.fc2.forward(x)
34    }
35}
36
/// Layer sizes for [`Model`].
///
/// `#[derive(Config)]` provides the constructor
/// `ModelConfig::new(input_size, hidden_size, output_size)`.
#[derive(Config, Debug)]
pub struct ModelConfig {
    /// Number of input features per row (input width of `fc1`).
    input_size: usize,
    /// Width of the hidden layer between `fc1` and `fc2`.
    hidden_size: usize,
    /// Number of outputs per row (output width of `fc2`).
    output_size: usize,
}
43
44impl ModelConfig {
45    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
46        Model {
47            fc1: LinearConfig::new(self.input_size, self.hidden_size).init(device),
48            fc2: LinearConfig::new(self.hidden_size, self.output_size).init(device),
49            activation: Relu::new(),
50        }
51    }
52}
53
/// Suggestor that runs the Burn [`Model`] over a [`FeatureVector`] proposal and
/// proposes the prediction as a hypothesis.
///
/// The agent holds no model: one is built inside `execute` on every call. This
/// keeps the struct free of a backend type parameter, which would otherwise
/// complicate using it behind the `Suggestor` trait object.
#[derive(Debug, Default)]
pub struct InferenceAgent {
    // NOTE(review): a real deployment would load trained weights once and keep
    // the model here (e.g. fixed to the `NdArray` backend, or shared via `Arc`)
    // instead of re-initialising it in every `execute` call.
}
62
63impl InferenceAgent {
64    pub fn new() -> Self {
65        Self {}
66    }
67}
68
69#[async_trait::async_trait]
70impl Suggestor for InferenceAgent {
71    fn name(&self) -> &'static str {
72        "InferenceAgent (Burn)"
73    }
74
75    fn dependencies(&self) -> &[ContextKey] {
76        &[ContextKey::Proposals]
77    }
78
79    fn accepts(&self, ctx: &dyn Context) -> bool {
80        // Run if there are proposals (features) but no hypothesis yet
81        ctx.has(ContextKey::Proposals) && !ctx.has(ContextKey::Hypotheses)
82    }
83
84    fn provenance(&self) -> &'static str {
85        PRISM_PROVENANCE.as_str()
86    }
87
88    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
89        // 1. Find the feature proposal
90        // In reality, filtered by typed Prism provenance plus feature metadata.
91        let _proposals = ctx.get(ContextKey::Proposals); // wait, ctx.get returns Fact, but proposals are ProposedFacts?
92        // Ah, ctx.get(ContextKey) returns FACTs (promoted).
93        // If FeatureAgent emits PROPOSALS, they are in `ContextKey::Proposals`?
94        // Wait, ContextKey::Proposals is a key where Validated Proposals might live?
95        // OR does FeatureAgent emit *Facts* directly if trusted?
96
97        // In the `engine.rs` implementation I sent `ProposedFact` with key `ContextKey::Proposals`.
98        // If they are not promoted to Facts, they are not in `ctx.get()`.
99        // `Context` only stores `facts`.
100        // Proposals usually sit in a queue in the Engine or are added to Context if Key::Proposals is a storage for them?
101        // Looking at `ContextKey` definition: "Internal storage for proposed facts before validation."
102        // So they ARE stored as FACTS under the key `Proposals` if the system works that way?
103        // OR `ProposedFact`s are converted to `Fact`s by the engine.
104        // `ProposedFact::try_from` converts to `Fact`.
105        // If the engine accepts the proposal, it adds it as a Fact.
106
107        // Let's assume the engine validated it and stored it.
108        // So we look for Facts in `ContextKey::Proposals`?
109        // Actually, normally `Proposals` key is for... proposals.
110        // But `FeatureAgent` intended to propose `context.key = Proposals`?
111        // No, `FeatureAgent` sent `proposal.key = Proposals`.
112
113        // Let's assume we find the features in `ContextKey::Proposals` (as stored Facts).
114
115        // We iterate and find one we haven't processed? For now just take the first.
116
117        // This logic is simplified for demo.
118
119        let facts = ctx.get(ContextKey::Proposals);
120        if facts.is_empty() {
121            return AgentEffect::empty();
122        }
123
124        // 2. Read typed features
125        let features = match facts[0].payload::<FeatureVector>() {
126            Some(features) => features,
127            None => return AgentEffect::empty(),
128        };
129
130        // 3. Run Inference (Burn)
131        type B = burn::backend::NdArray;
132        let device = Default::default();
133        let model: Model<B> = ModelConfig::new(3, 16, 1).init(&device);
134
135        let input = Tensor::<B, 1>::from_floats(features.data.as_slice(), &device)
136            .reshape([features.shape[0], features.shape[1]]);
137
138        let output = model.forward(input);
139
140        // 4. Emit Hypothesis
141        let values: Vec<f32> = output.into_data().to_vec::<f32>().unwrap_or_default();
142        let prediction = values[0]; // Assume single output
143
144        let hypo_content = format!("Prediction: {:.4} (based on {})", prediction, facts[0].id());
145
146        let hypothesis = PRISM_PROVENANCE.proposed_fact(
147            ContextKey::Hypotheses,
148            format!("hypo-{}", facts[0].id()),
149            TextPayload::new(hypo_content),
150        );
151
152        AgentEffect::with_proposal(hypothesis)
153    }
154}
155
156/// Run batch inference on a [`FeatureVector`] using a configured model.
157///
158/// Abstracts Burn internals: the caller provides a [`ModelConfig`] and
159/// a [`FeatureVector`] (shape [n, input_size]), and receives a `Vec<f32>`
160/// of per-sample predictions.
161///
162/// Uses the `NdArray` backend internally.
163pub fn run_batch_inference(
164    config: &ModelConfig,
165    features: &FeatureVector,
166) -> anyhow::Result<Vec<f32>> {
167    type B = burn::backend::NdArray;
168    let device = Default::default();
169    let model: Model<B> = config.init(&device);
170
171    let n = features.rows();
172    let input = Tensor::<B, 1>::from_floats(features.data.as_slice(), &device)
173        .reshape([n, config.input_size]);
174    let output = model.forward(input);
175    let values: Vec<f32> = output.into_data().to_vec::<f32>().unwrap_or_default();
176    Ok(values)
177}