use arrow::record_batch::RecordBatch;
use arrow_ipc::reader::StreamReader;
use arrow_ipc::writer::StreamWriter;
use arrow_schema::SchemaRef;
use tonic::Status;
pub mod proto;
/// Raw bytes of `jammi_descriptor.bin`, embedded at compile time from the
/// build's `OUT_DIR`.
///
/// NOTE(review): judging by the name this is presumably the encoded protobuf
/// file descriptor set emitted by the build script (e.g. for gRPC reflection)
/// — confirm against `build.rs`.
pub const FILE_DESCRIPTOR_SET: &[u8] =
include_bytes!(concat!(env!("OUT_DIR"), "/jammi_descriptor.bin"));
pub mod eval;
pub mod fine_tune;
pub mod request;
mod transport;
mod audit;
mod catalog;
mod channel;
mod embedding;
mod error;
mod eval_wire;
mod inference;
mod mutable_table;
mod training;
mod trigger;
pub use transport::{SessionChannel, SessionHeader, SessionTransport, SESSION_HEADER};
pub use audit::{parse_query_id, record_from_wire};
pub use catalog::{
derives_from_edge_from_proto, derives_from_edge_to_proto, match_verdict_from_proto,
match_verdict_to_proto, model_from_proto, model_to_proto, source_descriptor_from_proto,
source_type_from_proto, source_type_to_proto, staleness_from_proto, staleness_to_proto,
topic_from_proto, topic_to_proto,
};
pub use channel::{
channel_from_proto, channel_to_proto, columns_from_proto, columns_to_proto, parse_channel_id,
};
pub use embedding::{result_table_from_proto, result_table_with_outcome, ProtoQueryInput};
pub use error::{
attach_audit_detail, attach_error_detail, attach_trigger_detail, audit_error_from_status,
error_from_status, trigger_error_from_status,
};
pub use eval_wire::{
calibration_shape_from_proto, calibration_shape_to_proto, cohorts_from_proto, cohorts_to_proto,
eval_task_to_proto, EvalTaskFromWire,
};
pub use inference::infer_result_to_proto;
pub use mutable_table::{
definition_from_proto, definition_list_from_proto, definition_to_proto, parse_table_id,
};
pub use training::{config_to_proto, method_from_proto, method_to_proto};
pub use trigger::{
decode_publish_batch, decode_subscribed_batch, encode_delivered_batch, encode_publish_batch,
from_proto_timestamp, to_proto_timestamp,
};
/// Converts a wire-level `ModelTask` enum value into the domain
/// [`jammi_db::ModelTask`].
///
/// # Errors
///
/// Returns an `invalid_argument` [`Status`] with the message
/// `task must be specified` when `task` is the `Unspecified` variant or does
/// not correspond to any proto variant.
pub fn model_task_from_proto(task: i32) -> Result<jammi_db::ModelTask, Status> {
    use jammi_db::ModelTask as Domain;
    use proto::inference::ModelTask as Wire;

    let missing = || Status::invalid_argument("task must be specified");

    // An integer outside the proto enum is rejected exactly like `Unspecified`.
    let wire = Wire::try_from(task).map_err(|_| missing())?;
    let domain = match wire {
        Wire::Unspecified => return Err(missing()),
        Wire::TextEmbedding => Domain::TextEmbedding,
        Wire::ImageEmbedding => Domain::ImageEmbedding,
        Wire::AudioEmbedding => Domain::AudioEmbedding,
        Wire::Classification => Domain::Classification,
        Wire::Ner => Domain::Ner,
        Wire::Regression => Domain::Regression,
    };
    Ok(domain)
}
/// Converts a domain [`jammi_db::ModelTask`] into its proto `ModelTask`
/// counterpart.
///
/// Every domain variant maps to the proto variant of the same name; the match
/// is exhaustive, so this conversion cannot fail.
pub fn model_task_to_proto(task: jammi_db::ModelTask) -> proto::inference::ModelTask {
    use jammi_db::ModelTask as Domain;
    use proto::inference::ModelTask as Wire;

    match task {
        Domain::TextEmbedding => Wire::TextEmbedding,
        Domain::ImageEmbedding => Wire::ImageEmbedding,
        Domain::AudioEmbedding => Wire::AudioEmbedding,
        Domain::Classification => Wire::Classification,
        Domain::Ner => Wire::Ner,
        Domain::Regression => Wire::Regression,
    }
}
/// Serializes `batches` into a single in-memory Arrow IPC stream written with
/// `schema`.
///
/// The stream is built with a [`StreamWriter`] over a `Vec<u8>`: the writer is
/// created, each batch is written in order, and the writer is finished before
/// the buffer is returned.
///
/// # Errors
///
/// Returns an `internal` [`Status`] whose message is prefixed with
/// `batch encode: ` if creating the writer, writing any batch, or finishing
/// the stream fails.
pub fn encode_ipc_stream(schema: &SchemaRef, batches: &[RecordBatch]) -> Result<Vec<u8>, Status> {
    /// Wraps any encode-side failure in the shared `internal` status.
    fn encode_failure<E: std::fmt::Display>(err: E) -> Status {
        Status::internal(format!("batch encode: {err}"))
    }

    let mut encoded = Vec::<u8>::new();
    {
        // Scoped so the writer's mutable borrow of `encoded` ends before the
        // buffer is handed back to the caller.
        let mut ipc =
            StreamWriter::try_new(&mut encoded, schema.as_ref()).map_err(encode_failure)?;
        batches
            .iter()
            .try_for_each(|batch| ipc.write(batch))
            .map_err(encode_failure)?;
        ipc.finish().map_err(encode_failure)?;
    }
    Ok(encoded)
}
/// Extracts the schema from the start of an Arrow IPC stream.
///
/// Only the schema is read: a [`StreamReader`] is constructed over `bytes` and
/// its schema is returned without iterating any record batches.
///
/// # Errors
///
/// Returns an `invalid_argument` [`Status`] when
/// - `bytes` is empty (`schema is required`), or
/// - the stream reader cannot be constructed from `bytes`
///   (`schema decode: <cause>`).
pub fn decode_ipc_schema(bytes: &[u8]) -> Result<SchemaRef, Status> {
    if bytes.is_empty() {
        return Err(Status::invalid_argument("schema is required"));
    }
    // Read straight from the borrowed slice: a `Cursor<&[u8]>` implements
    // `Read`, so the input does not need to be copied into an owned buffer.
    // The returned `SchemaRef` is owned, so nothing borrows `bytes` afterwards.
    let cursor = std::io::Cursor::new(bytes);
    let reader = StreamReader::try_new(cursor, None)
        .map_err(|e| Status::invalid_argument(format!("schema decode: {e}")))?;
    Ok(reader.schema())
}
/// Decodes the record batches of an Arrow IPC stream supplied in two parts.
///
/// `data_header` and `data_body` are concatenated, in that order, and read as
/// one IPC stream. When both parts are empty, an empty vector is returned
/// without attempting to decode anything.
///
/// # Errors
///
/// Returns an `invalid_argument` [`Status`] whose message is prefixed with
/// `batch decode: ` if the stream reader cannot be constructed or any batch
/// fails to decode; decoding stops at the first failing batch.
pub fn decode_ipc_stream(data_header: &[u8], data_body: &[u8]) -> Result<Vec<RecordBatch>, Status> {
    /// Wraps any decode-side failure in the shared `invalid_argument` status.
    fn decode_failure<E: std::fmt::Display>(err: E) -> Status {
        Status::invalid_argument(format!("batch decode: {err}"))
    }

    let nothing_to_decode = data_header.is_empty() && data_body.is_empty();
    if nothing_to_decode {
        return Ok(Vec::new());
    }

    // Header followed by body forms the complete stream.
    let joined: Vec<u8> = [data_header, data_body].concat();
    let batches =
        StreamReader::try_new(std::io::Cursor::new(joined), None).map_err(decode_failure)?;

    let mut decoded = Vec::new();
    for batch in batches {
        decoded.push(batch.map_err(decode_failure)?);
    }
    Ok(decoded)
}