use std::error::Error as StdError;
use std::sync::Arc;
use bytes::Bytes;
use mnem_core::id::NodeId;
use mnem_core::objects::{Dtype, Embedding, Node};
use mnem_core::repo::ReadonlyRepo;
use mnem_core::store::{Blockstore, MemoryBlockstore, MemoryOpHeadsStore, OpHeadsStore};
use crate::adapter::{BenchAdapter, Hit, IngestDoc};
use crate::embed::BenchEmbedder;
/// Benchmark adapter that stores ingested documents as nodes in a `mnem`
/// repository (backed by in-memory stores) and answers queries through the
/// repository's vector retrieval.
pub struct MnemAdapter {
    /// Current repository snapshot; swapped out by `ingest` after each commit
    /// and re-initialised by `reset`.
    repo: ReadonlyRepo,
    /// Embedder applied both to ingested document text and to queries.
    embedder: BenchEmbedder,
    /// Node id assigned during `ingest` -> the document's external id, used to
    /// translate retrieval results back into caller-visible ids.
    id_to_external: std::collections::HashMap<NodeId, String>,
}
impl MnemAdapter {
pub fn new(dim: u32) -> Result<Self, Box<dyn StdError>> {
Self::with_embedder(BenchEmbedder::bag_of_tokens(dim))
}
pub fn with_embedder(embedder: BenchEmbedder) -> Result<Self, Box<dyn StdError>> {
let bs: Arc<dyn Blockstore> = Arc::new(MemoryBlockstore::default());
let ohs: Arc<dyn OpHeadsStore> = Arc::new(MemoryOpHeadsStore::default());
let repo = ReadonlyRepo::init(bs, ohs).map_err(|e| Box::new(e) as Box<dyn StdError>)?;
Ok(Self {
repo,
embedder,
id_to_external: std::collections::HashMap::new(),
})
}
pub fn embedder(&self) -> &BenchEmbedder {
&self.embedder
}
pub fn model_id(&self) -> &str {
self.embedder.model()
}
}
impl BenchAdapter for MnemAdapter {
    /// Discards all ingested data by installing a brand-new repository over
    /// fresh in-memory stores and clearing the node-id -> external-id map.
    ///
    /// # Errors
    /// Returns any error raised by `ReadonlyRepo::init`.
    fn reset(&mut self) -> Result<(), Box<dyn StdError>> {
        let bs: Arc<dyn Blockstore> = Arc::new(MemoryBlockstore::default());
        let ohs: Arc<dyn OpHeadsStore> = Arc::new(MemoryOpHeadsStore::default());
        self.repo = ReadonlyRepo::init(bs, ohs)?;
        self.id_to_external.clear();
        Ok(())
    }

    /// Ingests `docs` in a single transaction and commits it.
    ///
    /// Each document becomes a new node (fresh v7 id) labelled `d.label` with
    /// `d.text` as its summary. The node also carries:
    /// - every prop in `d.props` whose value yields `Some` from `as_str()`;
    /// - an `external_id` string prop;
    /// - the raw text as content when it is shorter than 64 KiB.
    ///
    /// The text is embedded with the adapter's embedder and attached under the
    /// embedder's model name.
    ///
    /// The external-id map is only updated once the commit has succeeded, so
    /// an error part-way through leaves the adapter's state untouched.
    ///
    /// # Errors
    /// Propagates embedding, node/embedding insertion and commit failures.
    fn ingest(&mut self, docs: &[IngestDoc]) -> Result<(), Box<dyn StdError>> {
        // Texts at or above this size are not stored as node content.
        const MAX_INLINE_CONTENT: usize = 1 << 16;

        if docs.is_empty() {
            return Ok(());
        }
        let mut tx = self.repo.start_transaction();
        // Deferred until after commit so a failed batch leaves no stale ids.
        let mut pending = Vec::with_capacity(docs.len());
        for d in docs {
            let id = NodeId::new_v7();
            let mut node = Node::new(id, d.label.as_str()).with_summary(d.text.as_str());
            // Only string-valued props are carried over; anything else is skipped.
            for (k, v) in &d.props {
                if let Some(s) = v.as_str() {
                    node = node.with_prop(k.as_str(), ipld_string(s));
                }
            }
            // Set after the user props so a user prop named `external_id`
            // cannot replace the canonical id that `retrieve` falls back to.
            node = node.with_prop("external_id", ipld_string(d.external_id.as_str()));
            if d.text.len() < MAX_INLINE_CONTENT {
                node = node.with_content(Bytes::from(d.text.clone().into_bytes()));
            }
            let vec = self.embedder.embed_text(&d.text)?;
            let emb = to_embedding_f32(self.embedder.model(), &vec);
            let cid = tx.add_node(&node)?;
            tx.set_embedding(cid, self.embedder.model().to_string(), emb)?;
            pending.push((id, d.external_id.clone()));
        }
        let next = tx.commit("mnem-bench", "bench ingest")?;
        self.repo = next;
        self.id_to_external.extend(pending);
        Ok(())
    }

    /// Returns up to `top_k` nodes labelled `label` ranked by vector
    /// similarity to the embedding of `query`.
    ///
    /// Each hit's external id comes from the in-memory map. If the map has no
    /// entry, it falls back to the node's `external_id` prop, and then to an
    /// empty string.
    ///
    /// # Errors
    /// Propagates embedding and retrieval failures.
    fn retrieve(
        &mut self,
        label: &str,
        query: &str,
        top_k: usize,
    ) -> Result<Vec<Hit>, Box<dyn StdError>> {
        // A zero-hit request is trivially satisfied. Returning early means the
        // query builder is only ever given a limit of at least 1, without
        // returning hits the caller did not ask for.
        if top_k == 0 {
            return Ok(Vec::new());
        }
        let qvec = self.embedder.embed_text(query)?;
        let result = self
            .repo
            .retrieve()
            .label(label)
            .vector(self.embedder.model().to_string(), qvec)
            .limit(top_k)
            .execute()?;
        let hits = result
            .items
            .into_iter()
            .map(|item| {
                let external_id = self
                    .id_to_external
                    .get(&item.node.id)
                    .cloned()
                    .or_else(|| node_external_id(&item.node))
                    .unwrap_or_default();
                Hit {
                    external_id,
                    score: item.score,
                }
            })
            .collect();
        Ok(hits)
    }

    /// Adapter name reported to the benchmark harness.
    fn name(&self) -> &str {
        "mnem"
    }
}
/// Wraps a borrowed string slice as an owned IPLD string value.
fn ipld_string(s: &str) -> ipld_core::ipld::Ipld {
    let owned: String = s.to_owned();
    ipld_core::ipld::Ipld::String(owned)
}
/// Reads the node's `external_id` prop.
///
/// Returns `None` when the prop is absent or is not an IPLD string.
fn node_external_id(node: &Node) -> Option<String> {
    match node.props.get("external_id") {
        Some(ipld_core::ipld::Ipld::String(text)) => Some(text.clone()),
        _ => None,
    }
}
/// Packs `v` into an `F32` [`Embedding`] tagged with `model`.
///
/// Each component is written as 4 little-endian bytes. Using a fixed byte
/// order makes the stored (and content-hashed) bytes identical on every host.
/// The previous native-endian encoding only matched this on little-endian
/// machines. `dim` is the number of components in `v`.
fn to_embedding_f32(model: &str, v: &[f32]) -> Embedding {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    for f in v {
        bytes.extend_from_slice(&f.to_le_bytes());
    }
    Embedding {
        model: model.to_string(),
        dtype: Dtype::F32,
        // NOTE(review): assumes embedder output never exceeds u32::MAX
        // components; a longer vector would be silently truncated here.
        dim: v.len() as u32,
        vector: Bytes::from(bytes),
    }
}