oris-runtime 0.15.0

An agentic workflow runtime and programmable AI execution system in Rust: stateful graphs, agents, tools, and multi-step execution.
use std::{collections::HashMap, sync::Arc};

use async_trait::async_trait;
use pgvector::Vector;
use serde_json::{json, Value};
use sqlx::{Pool, Postgres, Row};
use uuid::Uuid;

use crate::{
    embedding::embedder_trait::Embedder,
    schemas::Document,
    vectorstore::{VecStoreOptions, VectorStore, VectorStoreError},
};

pub struct Store {
    pub(crate) embedder: Arc<dyn Embedder>,
    pub(crate) pool: Pool<Postgres>,
    pub(crate) collection_name: String,
    pub(crate) collection_table_name: String,
    pub(crate) collection_uuid: String,
    pub(crate) collection_metadata: HashMap<String, Value>,
    pub(crate) embedder_table_name: String,
    pub(crate) pre_delete_collection: bool,
    pub(crate) vector_dimensions: i32,
    pub(crate) hns_index: Option<HNSWIndex>,
    pub(crate) vstore_options: PgOptions,
}

#[derive(Debug, Clone, PartialEq)]
pub enum PgFilter {
    Eq(PgLit, PgLit),
    Cmp(std::cmp::Ordering, PgLit, PgLit),
    In(PgLit, Vec<String>),
    And(Vec<PgFilter>),
    Or(Vec<PgFilter>),
}

pub type Column = String;

pub type Path = Vec<String>;

#[derive(Debug, Clone, PartialEq)]
pub enum PgLit {
    JsonField(Path),
    LitStr(String),
    RawJson(Value),
}

impl ToString for PgLit {
    fn to_string(&self) -> String {
        match self {
            PgLit::LitStr(str) => format!("'{}'", str.clone()),
            PgLit::JsonField(path) => format!("cmetadata#>>'{{{}}}'", path.join(",")),
            PgLit::RawJson(value) => serde_json::to_string(value).unwrap_or("null".to_string()),
        }
    }
}

impl ToString for PgFilter {
    fn to_string(&self) -> String {
        match self {
            PgFilter::Eq(a, b) => format!("{} = {}", a.to_string(), b.to_string()),
            PgFilter::Cmp(ordering, a, b) => {
                let op = match ordering {
                    std::cmp::Ordering::Less => "<",
                    std::cmp::Ordering::Greater => ">",
                    std::cmp::Ordering::Equal => "=",
                };
                format!("{} {} {}", a.to_string(), op, b.to_string())
            }
            PgFilter::In(a, values) => {
                format!(
                    "{} = ANY(ARRAY[{}])",
                    a.to_string(),
                    values
                        .iter()
                        .map(|s| format!("'{}'", s))
                        .collect::<Vec<String>>()
                        .join(",")
                )
            }
            PgFilter::And(pgfilters) => pgfilters
                .iter()
                .map(|pgf| pgf.to_string())
                .collect::<Vec<String>>()
                .join(" AND "),
            PgFilter::Or(pgfilters) => pgfilters
                .iter()
                .map(|pgf| pgf.to_string())
                .collect::<Vec<String>>()
                .join(" OR "),
        }
    }
}

pub struct HNSWIndex {
    pub(crate) m: i32,
    pub(crate) ef_construction: i32,
    pub(crate) distance_function: String,
}

impl HNSWIndex {
    pub fn new(m: i32, ef_construction: i32, distance_function: &str) -> Self {
        HNSWIndex {
            m,
            ef_construction,
            distance_function: distance_function.into(),
        }
    }
}

impl Store {
    fn get_filters(&self, opt: &PgOptions) -> Result<String, VectorStoreError> {
        match &opt.filters {
            Some(pgfilter) => Ok(pgfilter.to_string()),
            None => Ok("TRUE".to_string()), // No filters provided
        }
    }

    fn get_name_space(&self, opt: &PgOptions) -> String {
        match &opt.name_space {
            Some(name_space) => name_space.clone(),
            None => self.collection_name.clone(),
        }
    }

    fn get_score_threshold(&self, opt: &PgOptions) -> Result<f32, VectorStoreError> {
        match &opt.score_threshold {
            Some(score_threshold) => {
                if *score_threshold < 0.0 || *score_threshold > 1.0 {
                    return Err("Invalid score threshold".into());
                }
                Ok(*score_threshold)
            }
            None => Ok(0.0),
        }
    }

    async fn drop_tables(&self) -> Result<(), VectorStoreError> {
        sqlx::query(&format!(
            r#"DROP TABLE IF EXISTS {}"#,
            self.embedder_table_name
        ))
        .execute(&self.pool)
        .await
        .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;

        sqlx::query(&format!(
            r#"DROP TABLE IF EXISTS {}"#,
            self.collection_table_name
        ))
        .execute(&self.pool)
        .await
        .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;

        Ok(())
    }

    async fn remove_collection(&self) -> Result<(), VectorStoreError> {
        sqlx::query(r#"DELETE FROM collection WHERE uuid = $1"#)
            .bind(&self.collection_uuid)
            .execute(&self.pool)
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;
        Ok(())
    }
}

pub type PgOptions = VecStoreOptions<PgFilter>;

impl Default for PgOptions {
    fn default() -> Self {
        PgOptions {
            filters: None,
            score_threshold: None,
            name_space: None,
            embedder: None,
        }
    }
}

#[async_trait]
impl VectorStore for Store {
    type Options = PgOptions;

    async fn add_documents(
        &self,
        docs: &[Document],
        opt: &PgOptions,
    ) -> Result<Vec<String>, VectorStoreError> {
        if opt.score_threshold.is_some() || opt.filters.is_some() || opt.name_space.is_some() {
            return Err(VectorStoreError::InvalidParameter(
                "score_threshold, filters, and name_space are not supported in pgvector"
                    .to_string(),
            ));
        }
        let texts: Vec<String> = docs.iter().map(|d| d.page_content.clone()).collect();

        let embedder = opt.embedder.as_ref().unwrap_or(&self.embedder);

        let vectors = embedder.embed_documents(&texts).await?;

        if vectors.len() != docs.len() {
            return Err(VectorStoreError::InternalError(
                "Number of vectors and documents do not match".to_string(),
            ));
        }

        let mut tx = self
            .pool
            .begin()
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;

        let mut ids = Vec::with_capacity(docs.len());

        for (doc, vector) in docs.iter().zip(vectors.iter()) {
            let id = Uuid::new_v4().to_string();
            ids.push(id.clone());

            let vector_value =
                Vector::from(vector.into_iter().map(|x| *x as f32).collect::<Vec<f32>>());

            sqlx::query(&format!(
                r#"INSERT INTO {} 
(uuid, document, embedding, cmetadata, collection_id) VALUES ($1, $2, $3, $4, $5)"#,
                self.embedder_table_name
            ))
            .bind(&id)
            .bind(&doc.page_content)
            .bind(&vector_value)
            .bind(json!(&doc.metadata))
            .bind(&self.collection_uuid)
            .execute(&mut *tx)
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;
        }

        tx.commit()
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;

        Ok(ids)
    }

    async fn similarity_search(
        &self,
        query: &str,
        limit: usize,
        opt: &PgOptions,
    ) -> Result<Vec<Document>, VectorStoreError> {
        let collection_name = self.get_name_space(opt);
        let where_filter = self.get_filters(opt)?;

        let sql = format!(
            r#"WITH filtered_embedding_dims AS MATERIALIZED (
                SELECT
                    *
                FROM
                    {}
                WHERE
                    vector_dims(embedding) = $1
            )
            SELECT
                data.document,
                data.cmetadata,
                data.distance
            FROM (
                SELECT
                    filtered_embedding_dims.*,
                    embedding <=> $2 AS distance
                FROM
                    filtered_embedding_dims
                    JOIN {} ON filtered_embedding_dims.collection_id = {}.uuid
                WHERE {}.name = '{}'
            ) AS data
            WHERE {}
            ORDER BY
                data.distance DESC
            LIMIT $3"#,
            self.embedder_table_name,
            self.collection_table_name,
            self.collection_table_name,
            self.collection_table_name,
            collection_name,
            where_filter,
        );

        let query_vector = self.embedder.embed_query(query).await?;

        let vector_dims = query_vector.len();

        let rows = sqlx::query(&sql)
            .bind(vector_dims as i64)
            .bind(&Vector::from(
                query_vector
                    .into_iter()
                    .map(|x| x as f32)
                    .collect::<Vec<f32>>(),
            ))
            .bind(limit as i32)
            .fetch_all(&self.pool)
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;

        let docs = rows
            .into_iter()
            .map(|row| {
                let page_content: String = row
                    .try_get(0)
                    .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;
                let metadata_json: Value = row
                    .try_get(1)
                    .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;
                let score: f64 = row
                    .try_get(2)
                    .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;

                let metadata = if let Value::Object(obj) = metadata_json {
                    obj.into_iter().collect()
                } else {
                    HashMap::new()
                };

                Ok(Document {
                    page_content,
                    metadata,
                    score,
                })
            })
            .collect::<Result<Vec<Document>, VectorStoreError>>()?;

        Ok(docs)
    }

    async fn delete(&self, ids: &[String], _opt: &PgOptions) -> Result<(), VectorStoreError> {
        if ids.is_empty() {
            return Ok(());
        }
        // pgvector uses uuid in the embedding table
        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("${}", i)).collect();
        let sql = format!(
            "DELETE FROM {} WHERE uuid IN ({})",
            self.embedder_table_name,
            placeholders.join(", ")
        );
        let mut query = sqlx::query(&sql);
        for id in ids {
            query = query.bind(id);
        }
        query
            .execute(&self.pool)
            .await
            .map_err(|e| VectorStoreError::Unknown(e.to_string()))?;
        Ok(())
    }
}