use jammi_db::catalog::result_repo::ResultTableRecord;
use tonic::Status;
use crate::proto::embedding as pb;
use crate::request::{Modality, QueryInput};
impl TryFrom<pb::Modality> for Modality {
    type Error = Status;

    /// Maps the wire-level modality enum onto the domain [`Modality`].
    ///
    /// # Errors
    ///
    /// Returns `INVALID_ARGUMENT` when the client left the modality as
    /// `UNSPECIFIED`.
    fn try_from(modality: pb::Modality) -> Result<Self, Self::Error> {
        // Exhaustive on purpose: a new wire variant must be mapped explicitly.
        let mapped = match modality {
            pb::Modality::Unspecified => None,
            pb::Modality::Text => Some(Self::Text),
            pb::Modality::Image => Some(Self::Image),
            pb::Modality::Audio => Some(Self::Audio),
        };
        mapped.ok_or_else(|| Status::invalid_argument("modality must be specified"))
    }
}
impl TryFrom<i32> for Modality {
    type Error = Status;

    /// Decodes a raw protobuf enum value (as carried in a message field) into
    /// the domain [`Modality`].
    ///
    /// # Errors
    ///
    /// Returns `INVALID_ARGUMENT` when the value is not a known
    /// `pb::Modality` discriminant, or when it decodes to `UNSPECIFIED`.
    fn try_from(modality: i32) -> Result<Self, Self::Error> {
        // An out-of-range value is a different client mistake than leaving the
        // field unset; report the offending value instead of claiming it was
        // "not specified". The UNSPECIFIED case is handled by the delegate below.
        let proto = pb::Modality::try_from(modality).map_err(|_| {
            Status::invalid_argument(format!("unknown modality value: {modality}"))
        })?;
        Modality::try_from(proto)
    }
}
/// Query payload taken from an encode-query request's `input` oneof, paired
/// with the domain [`Modality`] it must agree with.
///
/// Converted into a [`QueryInput`] via `TryFrom`, which validates that the
/// input kind (text vs. data) matches the modality.
pub struct ProtoQueryInput {
/// The `oneof` input (text or data); `None` when the client set neither.
pub input: Option<pb::encode_query_request::Input>,
/// Modality the input is checked against during conversion.
pub modality: Modality,
}
impl TryFrom<ProtoQueryInput> for QueryInput {
type Error = Status;
fn try_from(value: ProtoQueryInput) -> Result<Self, Self::Error> {
use pb::encode_query_request::Input as ProtoInput;
let input = value
.input
.ok_or_else(|| Status::invalid_argument("input (text or data) is required"))?;
match (value.modality, input) {
(Modality::Text, ProtoInput::Text(text)) => {
if text.is_empty() {
return Err(Status::invalid_argument("text is required"));
}
Ok(QueryInput::Text(text))
}
(Modality::Image | Modality::Audio, ProtoInput::Data(data)) => {
if data.is_empty() {
return Err(Status::invalid_argument("data is required"));
}
Ok(QueryInput::Bytes(data))
}
(Modality::Text, ProtoInput::Data(_)) => Err(Status::invalid_argument(
"TEXT modality requires text input, got data",
)),
(Modality::Image | Modality::Audio, ProtoInput::Text(_)) => Err(
Status::invalid_argument("IMAGE/AUDIO modality requires data input, got text"),
),
}
}
}
impl From<ResultTableRecord> for pb::ResultTable {
fn from(record: ResultTableRecord) -> Self {
pb::ResultTable {
table_name: record.table_name,
source_id: record.source_id,
model_id: record.model_id,
dimensions: record.dimensions.unwrap_or(0),
row_count: record.row_count as u64,
status: record.status,
task: super::model_task_to_proto(record.task) as i32,
cache_outcome: crate::proto::inference::CacheOutcome::Unspecified as i32,
}
}
}
/// Converts `record` to its wire form and stamps it with a cache outcome.
///
/// `outcome` is written verbatim into `cache_outcome`, which holds a
/// `CacheOutcome` value cast to `i32`. It is not validated here.
pub fn result_table_with_outcome(record: ResultTableRecord, outcome: i32) -> pb::ResultTable {
    let mut table: pb::ResultTable = record.into();
    table.cache_outcome = outcome;
    table
}
pub fn result_table_from_proto(table: pb::ResultTable) -> Result<ResultTableRecord, Status> {
let task = super::model_task_from_proto(table.task)?;
Ok(ResultTableRecord {
table_name: table.table_name,
source_id: table.source_id,
model_id: table.model_id,
task,
kind: jammi_db::catalog::result_repo::ResultTableKind::Model,
derived_from: None,
parquet_path: String::new(),
index_path: None,
dimensions: (table.dimensions != 0).then_some(table.dimensions),
distance_metric: String::new(),
row_count: table.row_count as usize,
status: table.status,
key_column: None,
text_columns: None,
created_at: String::new(),
completed_at: None,
tenant_id: None,
definition_hash: None,
input_anchors_json: None,
})
}