oris-runtime 0.61.0

An agentic workflow runtime and programmable AI execution system in Rust: stateful graphs, agents, tools, and multi-step execution.
use std::collections::HashMap;

use async_trait::async_trait;
use serde_json::Value;
use uuid::Uuid;
use weaviate_community::collections::objects::{Object, ObjectBuilder};
use weaviate_community::collections::query::GetBuilder;
use weaviate_community::WeaviateClient;

use crate::{
    embedding::embedder_trait::Embedder,
    schemas::Document,
    vectorstore::{VecStoreOptions, VectorStore, VectorStoreError},
};

/// Vector store backed by a Weaviate class.
///
/// Documents are written as objects of `class_name` and looked up with
/// near-vector queries against that same class.
pub struct Store {
    /// Weaviate client used to create objects, run `Get` queries and delete objects.
    pub client: WeaviateClient,
    /// Name of the Weaviate class that objects are written to and queried from.
    pub class_name: String,
    /// Embedder used whenever the per-call options do not supply their own.
    pub embedder: std::sync::Arc<dyn Embedder>,
}

/// Per-call options for the Weaviate store.
///
/// The filter slot holds a `serde_json::Value`; `similarity_search` only
/// applies it when the value is a JSON string, which is passed to the query
/// builder as the `where` clause.
pub type WeaviateOptions = VecStoreOptions<Value>;

#[async_trait]
impl VectorStore for Store {
    type Options = WeaviateOptions;

    async fn add_documents(
        &self,
        docs: &[Document],
        opt: &WeaviateOptions,
    ) -> Result<Vec<String>, VectorStoreError> {
        let _ = opt;
        let embedder = opt.embedder.as_ref().unwrap_or(&self.embedder);
        let texts: Vec<String> = docs.iter().map(|d| d.page_content.clone()).collect();
        let vectors = embedder.embed_documents(&texts).await?;
        if vectors.len() != docs.len() {
            return Err(VectorStoreError::InternalError(
                "Number of vectors and documents do not match".to_string(),
            ));
        }
        let mut ids = Vec::with_capacity(docs.len());
        for (doc, vec_f64) in docs.iter().zip(vectors.into_iter()) {
            let id = Uuid::new_v4();
            ids.push(id.to_string());
            let metadata_json =
                serde_json::to_string(&doc.metadata).unwrap_or_else(|_| "{}".to_string());
            let properties = serde_json::json!({
                "content": doc.page_content,
                "metadata": metadata_json,
            });
            let obj: Object = ObjectBuilder::new(&self.class_name, properties)
                .with_id(id)
                .with_vector(vec_f64)
                .build();
            self.client
                .objects
                .create(&obj, None)
                .await
                .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;
        }
        Ok(ids)
    }

    async fn similarity_search(
        &self,
        query: &str,
        limit: usize,
        opt: &WeaviateOptions,
    ) -> Result<Vec<Document>, VectorStoreError> {
        let embedder = opt.embedder.as_ref().unwrap_or(&self.embedder);
        let qv = embedder.embed_query(query).await?;
        let near_vector_str = serde_json::to_string(&serde_json::json!({ "vector": qv }))?;
        let mut get = GetBuilder::new(&self.class_name, vec!["content", "metadata"])
            .with_limit(limit as u32)
            .with_near_vector(&near_vector_str)
            .with_additional(vec!["distance", "certainty"]);
        if let Some(ref w) = opt.filters {
            if let Some(s) = w.as_str() {
                get = get.with_where(s);
            }
        }
        let query_result = get.build();
        let raw = self
            .client
            .query
            .get(query_result)
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;
        let score_threshold = opt
            .score_threshold
            .map(f64::from)
            .unwrap_or(f64::NEG_INFINITY);
        let docs = parse_get_response(&raw, &self.class_name, score_threshold)?;
        Ok(docs)
    }

    async fn delete(&self, ids: &[String], _opt: &WeaviateOptions) -> Result<(), VectorStoreError> {
        if ids.is_empty() {
            return Ok(());
        }
        for id in ids {
            let uuid = Uuid::parse_str(id).map_err(|e| {
                VectorStoreError::InvalidParameter(format!("invalid uuid {}: {}", id, e))
            })?;
            self.client
                .objects
                .delete(&self.class_name, &uuid, None, None)
                .await
                .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;
        }
        Ok(())
    }
}

fn parse_get_response(
    raw: &Value,
    class_name: &str,
    score_threshold: f64,
) -> Result<Vec<Document>, VectorStoreError> {
    let get = raw
        .get("data")
        .and_then(|d| d.get("Get"))
        .and_then(|g| g.get(class_name))
        .and_then(|c| c.as_array())
        .ok_or_else(|| {
            VectorStoreError::DeserializationError(format!(
                "expected data.Get.{} in Weaviate response",
                class_name
            ))
        })?;
    let mut out = Vec::new();
    for obj in get {
        let content = obj
            .get("content")
            .and_then(|c| c.as_str())
            .unwrap_or("")
            .to_string();
        let metadata: HashMap<String, Value> = obj
            .get("metadata")
            .and_then(|m| serde_json::from_str(m.as_str().unwrap_or("{}")).ok())
            .unwrap_or_default();
        let score = obj
            .get("_additional")
            .and_then(|a| {
                a.get("certainty")
                    .and_then(|c| c.as_f64())
                    .or_else(|| a.get("distance").and_then(|d| d.as_f64()).map(|d| 1.0 - d))
            })
            .unwrap_or(0.0);
        if score >= score_threshold {
            out.push(Document {
                page_content: content,
                metadata,
                score,
            });
        }
    }
    Ok(out)
}