use crate::engine::FeatureVector;
use burn::{
    nn::{Linear, LinearConfig, Relu},
    prelude::*,
};
use converge_core::{AgentEffect, ContextKey, ProposedFact, Suggestor};
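/// A minimal two-layer perceptron: `fc1` -> ReLU -> `fc2`.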
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
activation: Relu,
}
impl<B: Backend> Model<B> {
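    /// Builds a model with the default layer sizes (3 -> 16 -> 1).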
pub fn new(device: &B::Device) -> Self {
let config = ModelConfig::new(3, 16, 1);
config.init(device)
}
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let x = self.fc1.forward(input);
let x = self.activation.forward(x);
self.fc2.forward(x)
}
}
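/// Layer sizes for `Model`; `#[derive(Config)]` generates `ModelConfig::new`.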
#[derive(Config, Debug)]
pub struct ModelConfig {
input_size: usize,
hidden_size: usize,
output_size: usize,
}
impl ModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
Model {
fc1: LinearConfig::new(self.input_size, self.hidden_size).init(device),
fc2: LinearConfig::new(self.hidden_size, self.output_size).init(device),
activation: Relu::new(),
}
}
}
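/// A `Suggestor` that deserializes the first proposal's content as a JSON
/// `FeatureVector`, scores it with `Model`, and proposes the prediction as a
/// hypothesis.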
#[derive(Debug, Default)]
pub struct InferenceAgent;

impl InferenceAgent {
    pub fn new() -> Self {
        Self
    }
}
#[async_trait::async_trait]
impl Suggestor for InferenceAgent {
fn name(&self) -> &str {
"InferenceAgent (Burn)"
}
fn dependencies(&self) -> &[ContextKey] {
&[ContextKey::Proposals]
}
fn accepts(&self, ctx: &dyn converge_core::ContextView) -> bool {
ctx.has(ContextKey::Proposals) && !ctx.has(ContextKey::Hypotheses)
}
    async fn execute(&self, ctx: &dyn converge_core::ContextView) -> AgentEffect {
        let facts = ctx.get(ContextKey::Proposals);
        if facts.is_empty() {
            return AgentEffect::empty();
        }
        // The first proposal's content is expected to be a JSON-encoded `FeatureVector`.
        let features: FeatureVector = match serde_json::from_str(&facts[0].content) {
            Ok(f) => f,
            Err(_) => return AgentEffect::empty(),
        };
        type B = burn::backend::NdArray;
        let device = Default::default();
        // NOTE: the model is freshly initialized here, so its weights are random;
        // predictions are only meaningful once trained weights are loaded.
        let model: Model<B> = ModelConfig::new(3, 16, 1).init(&device);
        let input = Tensor::<B, 1>::from_floats(features.data.as_slice(), &device)
            .reshape([features.shape[0], features.shape[1]]);
        let output = model.forward(input);
        let values: Vec<f32> = output.into_data().to_vec::<f32>().unwrap_or_default();
        // Guard against an empty output instead of indexing and panicking.
        let prediction = match values.first() {
            Some(&v) => v,
            None => return AgentEffect::empty(),
        };
let hypo_content = format!("Prediction: {:.4} (based on {})", prediction, facts[0].id);
let hypothesis = ProposedFact::new(
ContextKey::Hypotheses,
format!("hypo-{}", facts[0].id),
hypo_content,
self.name(),
);
AgentEffect::with_proposal(hypothesis)
}
}
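/// Runs every row of `features` through a freshly initialized `Model`,
/// returning one prediction per row. The weights are random on each call;
/// load trained weights before relying on the output.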
pub fn run_batch_inference(
config: &ModelConfig,
features: &FeatureVector,
) -> anyhow::Result<Vec<f32>> {
type B = burn::backend::NdArray;
let device = Default::default();
let model: Model<B> = config.init(&device);
    let n = features.rows();
    // Fail early if the flat buffer cannot be reshaped to `[n, input_size]`.
    anyhow::ensure!(features.data.len() == n * config.input_size, "feature buffer length mismatch");
    let input = Tensor::<B, 1>::from_floats(features.data.as_slice(), &device)
        .reshape([n, config.input_size]);
    let output = model.forward(input);
    // Propagate conversion failures instead of silently returning an empty vec.
    let values = output
        .into_data()
        .to_vec::<f32>()
        .map_err(|e| anyhow::anyhow!("failed to read tensor data: {e:?}"))?;
    Ok(values)
}
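
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal shape smoke test sketch. It assumes the `burn` NdArray backend
    // is available (it is already used by the functions above) and avoids
    // `FeatureVector` so it stays independent of `crate::engine`.
    #[test]
    fn forward_yields_one_prediction_per_row() {
        type B = burn::backend::NdArray;
        let device = Default::default();
        let model: Model<B> = ModelConfig::new(3, 16, 1).init(&device);
        // Two rows of three features each, flattened then reshaped to [2, 3].
        let input = Tensor::<B, 1>::from_floats([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], &device)
            .reshape([2usize, 3]);
        let output = model.forward(input);
        // A [2, 3] input through a 3 -> 16 -> 1 network gives a [2, 1] output.
        assert_eq!(output.dims(), [2, 1]);
    }
}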