1use std::collections::{BTreeMap, HashMap};
25use std::pin::Pin;
26use std::sync::Arc;
27
28use arrow::array::{ArrayRef, Float32Array, RecordBatch, StringArray};
29use arrow::datatypes::{DataType, Field, Schema};
30use arrow_flight::sql::client::FlightSqlServiceClient;
31use futures::{Stream, StreamExt, TryStreamExt};
32use tonic::transport::Endpoint;
33
34use jammi_db::catalog::eval_repo::PerQueryEvalRecord;
35use jammi_db::catalog::result_repo::ResultTableRecord;
36use jammi_db::error::{JammiError, Result};
37use jammi_db::store::{CacheOutcome, CachePolicy};
38use jammi_db::trigger::{DeliveredBatch, Offset, Predicate, TopicDefinition, TriggerError};
39use jammi_db::{AuditError, ModelTask, PerQueryAudit, TenantId};
40
41use jammi_admin::CatalogClient;
42use jammi_wire::eval::{CompareEvalReport, EmbeddingEvalReport, EvalTask, InferenceEvalReport};
43use jammi_wire::fine_tune::{FineTuneConfig, FineTuneMethod};
44use jammi_wire::proto::audit::audit_service_client::AuditServiceClient;
45use jammi_wire::proto::audit::{
46 AuditFetchByQueryIdRequest, AuditFetchRecentRequest, AuditLogRequest,
47};
48use jammi_wire::proto::embedding::embedding_service_client::EmbeddingServiceClient;
49use jammi_wire::proto::embedding::{
50 encode_query_request::Input as ProtoEncodeInput, search_request::Query as ProtoSearchQuery,
51 EncodeQueryRequest, GenerateEmbeddingsRequest, QueryVector,
52 SearchRequest as ProtoSearchRequest, SearchResponse,
53};
54use jammi_wire::proto::eval as eval_pb;
55use jammi_wire::proto::eval::eval_service_client::EvalServiceClient;
56use jammi_wire::proto::inference::inference_service_client::InferenceServiceClient;
57use jammi_wire::proto::inference::InferRequest;
58use jammi_wire::proto::training::training_service_client::TrainingServiceClient;
59use jammi_wire::proto::training::{
60 start_training_request::Spec as ProtoTrainingSpec, FineTuneSpec, StartTrainingRequest,
61 TrainingStatusRequest,
62};
63use jammi_wire::proto::trigger::trigger_service_client::TriggerServiceClient;
64use jammi_wire::proto::trigger::{PublishRequest, SubscribeRequest, TopicName};
65use jammi_wire::request::{FineTuneJobId, Modality, QueryInput, SearchQuery, SearchRequest};
66use jammi_wire::{
67 audit_error_from_status, cohorts_to_proto, config_to_proto, decode_ipc_stream,
68 decode_subscribed_batch, encode_publish_batch, error_from_status, eval_task_to_proto,
69 method_to_proto, model_task_to_proto, record_from_wire, result_table_from_proto,
70 trigger_error_from_status, SessionChannel, SessionTransport, SESSION_HEADER,
71};
72
/// Client for the data-plane services (embedding, inference, eval, training,
/// trigger, audit) and Flight SQL, all driven over one [`SessionTransport`].
///
/// Cloning is derived; both fields are cloned with it.
#[derive(Clone)]
pub struct DataClient {
 // Transport that every per-service gRPC client is built from.
 transport: SessionTransport,
 // Catalog client constructed over a clone of `transport` (see `over`).
 catalog: CatalogClient,
}
82
83impl DataClient {
84 pub async fn connect(endpoint: impl Into<Endpoint>) -> Result<Self> {
88 let transport = SessionTransport::connect(endpoint).await?;
89 Ok(Self::over(transport))
90 }
91
92 pub fn over(transport: SessionTransport) -> Self {
95 let catalog = CatalogClient::over(transport.clone());
96 Self { transport, catalog }
97 }
98
99 pub fn catalog(&self) -> &CatalogClient {
102 &self.catalog
103 }
104
105 pub fn session_id(&self) -> &str {
108 self.transport.session_id()
109 }
110
111 fn embedding_client(&self) -> EmbeddingServiceClient<SessionChannel> {
112 self.transport
113 .service(EmbeddingServiceClient::with_interceptor)
114 }
115
116 fn inference_client(&self) -> InferenceServiceClient<SessionChannel> {
117 self.transport
118 .service(InferenceServiceClient::with_interceptor)
119 }
120
121 fn eval_client(&self) -> EvalServiceClient<SessionChannel> {
122 self.transport.service(EvalServiceClient::with_interceptor)
123 }
124
125 fn training_client(&self) -> TrainingServiceClient<SessionChannel> {
126 self.transport
127 .service(TrainingServiceClient::with_interceptor)
128 }
129
130 fn trigger_client(&self) -> TriggerServiceClient<SessionChannel> {
131 self.transport
132 .service(TriggerServiceClient::with_interceptor)
133 }
134
135 fn audit_client(&self) -> AuditServiceClient<SessionChannel> {
136 self.transport.service(AuditServiceClient::with_interceptor)
137 }
138
139 pub async fn bind_tenant(&self, t: TenantId) -> Result<()> {
143 self.catalog.bind_tenant(t).await
144 }
145
146 pub async fn unbind_tenant(&self) -> Result<()> {
148 self.catalog.unbind_tenant().await
149 }
150
151 pub async fn tenant(&self) -> Result<Option<TenantId>> {
153 self.catalog.tenant().await
154 }
155
156 pub async fn sql(&self, query: &str) -> Result<Vec<RecordBatch>> {
171 let mut client = FlightSqlServiceClient::new(self.transport.channel());
172 client.set_header(SESSION_HEADER, self.session_id().to_string());
173 let info = client
174 .execute(query.to_string(), None)
175 .await
176 .map_err(|e| JammiError::Other(format!("flight sql execute: {e}")))?;
177
178 let mut batches = Vec::new();
179 for endpoint in info.endpoint {
180 let ticket = endpoint
181 .ticket
182 .ok_or_else(|| JammiError::Other("flight sql endpoint carried no ticket".into()))?;
183 let stream = client
184 .do_get(ticket)
185 .await
186 .map_err(|e| JammiError::Other(format!("flight sql do_get: {e}")))?;
187 let endpoint_batches: Vec<RecordBatch> = stream
188 .try_collect()
189 .await
190 .map_err(|e| JammiError::Other(format!("flight sql stream: {e}")))?;
191 batches.extend(endpoint_batches);
192 }
193 Ok(batches)
194 }
195
196 pub async fn generate_embeddings(
201 &self,
202 source_id: &str,
203 model_id: &str,
204 columns: &[String],
205 key_column: &str,
206 modality: Modality,
207 cache: CachePolicy,
208 ) -> Result<(ResultTableRecord, CacheOutcome)> {
209 let table = self
210 .embedding_client()
211 .generate_embeddings(GenerateEmbeddingsRequest {
212 source_id: source_id.to_string(),
213 model_id: model_id.to_string(),
214 columns: columns.to_vec(),
215 key_column: key_column.to_string(),
216 modality: proto_modality(modality) as i32,
217 cache: proto_cache_policy(cache) as i32,
218 })
219 .await
220 .map_err(|s| error_from_status(&s))?
221 .into_inner();
222 let outcome = cache_outcome_from_proto(table.cache_outcome, &table.table_name)
223 .map_err(|s| error_from_status(&s))?;
224 let record = result_table_from_proto(table).map_err(|s| error_from_status(&s))?;
225 Ok((record, outcome))
226 }
227
228 pub async fn encode_query(
230 &self,
231 model_id: &str,
232 input: QueryInput,
233 modality: Modality,
234 ) -> Result<Vec<f32>> {
235 let input = match input {
236 QueryInput::Text(text) => ProtoEncodeInput::Text(text),
237 QueryInput::Bytes(bytes) => ProtoEncodeInput::Data(bytes),
238 };
239 let resp = self
240 .embedding_client()
241 .encode_query(EncodeQueryRequest {
242 model_id: model_id.to_string(),
243 modality: proto_modality(modality) as i32,
244 input: Some(input),
245 })
246 .await
247 .map_err(|s| error_from_status(&s))?
248 .into_inner();
249 Ok(resp.embedding)
250 }
251
252 pub async fn search(&self, request: SearchRequest) -> Result<Vec<RecordBatch>> {
256 let SearchRequest {
257 source_id,
258 query,
259 k,
260 embedding_table,
261 filter,
262 select,
263 } = request;
264 let query = match query {
265 SearchQuery::Vector(values) => ProtoSearchQuery::QueryVector(QueryVector { values }),
266 SearchQuery::RowKey(key) => ProtoSearchQuery::RowKey(key),
267 };
268 let resp = self
269 .embedding_client()
270 .search(ProtoSearchRequest {
271 source_id,
272 query: Some(query),
273 k: k as u32,
274 embedding_table,
275 filter,
276 select: select.clone(),
277 })
278 .await
279 .map_err(|s| error_from_status(&s))?
280 .into_inner();
281 hits_to_batch(resp, &select)
282 }
283
284 pub async fn infer(
288 &self,
289 source_id: &str,
290 model_id: &str,
291 task: ModelTask,
292 content_columns: &[String],
293 key_column: &str,
294 cache: CachePolicy,
295 ) -> Result<(Vec<RecordBatch>, CacheOutcome)> {
296 let resp = self
297 .inference_client()
298 .infer(InferRequest {
299 source_id: source_id.to_string(),
300 model_id: model_id.to_string(),
301 task: model_task_to_proto(task) as i32,
302 columns: content_columns.to_vec(),
303 key_column: key_column.to_string(),
304 tenant_id: String::new(),
305 cache: proto_cache_policy(cache) as i32,
306 })
307 .await
308 .map_err(|s| error_from_status(&s))?
309 .into_inner();
310 let outcome =
314 cache_outcome_from_proto(resp.cache_outcome, "").map_err(|s| error_from_status(&s))?;
315 let batch = resp.result.unwrap_or_default();
316 let batches = decode_ipc_stream(&batch.data_header, &batch.data_body)
317 .map_err(|s| error_from_status(&s))?;
318 Ok((batches, outcome))
319 }
320
321 pub async fn fine_tune(
326 &self,
327 source: &str,
328 base_model: &str,
329 columns: &[String],
330 method: FineTuneMethod,
331 task: ModelTask,
332 config: Option<FineTuneConfig>,
333 ) -> Result<FineTuneJobId> {
334 let resp = self
339 .training_client()
340 .start_training(StartTrainingRequest {
341 spec: Some(ProtoTrainingSpec::FineTune(FineTuneSpec {
342 source: source.to_string(),
343 columns: columns.to_vec(),
344 method: method_to_proto(method) as i32,
345 task: model_task_to_proto(task) as i32,
346 })),
347 base_model: base_model.to_string(),
348 config: config.as_ref().map(config_to_proto),
349 })
350 .await
351 .map_err(|s| error_from_status(&s))?
352 .into_inner();
353 Ok(FineTuneJobId(resp.job_id))
354 }
355
356 pub async fn fine_tune_status(&self, id: &FineTuneJobId) -> Result<String> {
358 let resp = self
359 .training_client()
360 .training_status(TrainingStatusRequest {
361 job_id: id.0.clone(),
362 })
363 .await
364 .map_err(|s| error_from_status(&s))?
365 .into_inner();
366 Ok(resp.status)
367 }
368
369 pub async fn eval_embeddings(
373 &self,
374 source_id: &str,
375 embedding_table: Option<&str>,
376 golden_source: &str,
377 k: usize,
378 cohorts: &HashMap<String, BTreeMap<String, String>>,
379 ) -> Result<EmbeddingEvalReport> {
380 let resp = self
381 .eval_client()
382 .eval_embeddings(eval_pb::EvalEmbeddingsRequest {
383 source_id: source_id.to_string(),
384 embedding_table: embedding_table.unwrap_or_default().to_string(),
385 golden_source: golden_source.to_string(),
386 k: k as u32,
387 cohorts: cohorts_to_proto(cohorts),
388 tenant_id: String::new(),
389 })
390 .await
391 .map_err(|s| error_from_status(&s))?
392 .into_inner();
393 resp.try_into()
394 }
395
396 pub async fn eval_per_query(&self, eval_run_id: &str) -> Result<Vec<PerQueryEvalRecord>> {
398 let resp = self
399 .eval_client()
400 .eval_per_query(eval_pb::EvalPerQueryRequest {
401 eval_run_id: eval_run_id.to_string(),
402 tenant_id: String::new(),
403 })
404 .await
405 .map_err(|s| error_from_status(&s))?
406 .into_inner();
407 Ok(resp.records.into_iter().map(Into::into).collect())
408 }
409
410 pub async fn eval_inference(
412 &self,
413 model_id: &str,
414 source_id: &str,
415 columns: &[String],
416 task: EvalTask,
417 golden_source: &str,
418 label_column: &str,
419 ) -> Result<InferenceEvalReport> {
420 let resp = self
421 .eval_client()
422 .eval_inference(eval_pb::EvalInferenceRequest {
423 model_id: model_id.to_string(),
424 source_id: source_id.to_string(),
425 columns: columns.to_vec(),
426 task: eval_task_to_proto(task) as i32,
427 golden_source: golden_source.to_string(),
428 label_column: label_column.to_string(),
429 tenant_id: String::new(),
430 })
431 .await
432 .map_err(|s| error_from_status(&s))?
433 .into_inner();
434 resp.try_into()
435 }
436
437 pub async fn eval_compare(
439 &self,
440 embedding_tables: &[String],
441 source_id: &str,
442 golden_source: &str,
443 k: usize,
444 ) -> Result<CompareEvalReport> {
445 let resp = self
446 .eval_client()
447 .eval_compare(eval_pb::EvalCompareRequest {
448 embedding_tables: embedding_tables.to_vec(),
449 source_id: source_id.to_string(),
450 golden_source: golden_source.to_string(),
451 k: k as u32,
452 tenant_id: String::new(),
453 })
454 .await
455 .map_err(|s| error_from_status(&s))?
456 .into_inner();
457 resp.try_into()
458 }
459
460 pub async fn publish(
465 &self,
466 topic: &TopicDefinition,
467 batch: RecordBatch,
468 ) -> std::result::Result<Offset, TriggerError> {
469 let wire_batch = encode_publish_batch(&batch).map_err(|s| trigger_error_from_status(&s))?;
470 let resp = self
471 .trigger_client()
472 .publish(PublishRequest {
473 topic: Some(TopicName {
474 name: topic.name.clone(),
475 }),
476 batch: Some(wire_batch),
477 tenant_id: String::new(),
479 })
480 .await
481 .map_err(|s| trigger_error_from_status(&s))?
482 .into_inner();
483 let committed_at = resp
484 .committed_at
485 .as_ref()
486 .map(jammi_wire::from_proto_timestamp)
487 .transpose()
488 .map_err(|s| trigger_error_from_status(&s))?
489 .ok_or_else(|| TriggerError::Driver("publish response missing committed_at".into()))?;
490 Ok(Offset::new(resp.offset, committed_at))
491 }
492
493 pub async fn subscribe(
499 &self,
500 topic: &TopicDefinition,
501 predicate: Predicate,
502 from_offset: Option<Offset>,
503 replay_only: bool,
504 ) -> std::result::Result<
505 Pin<Box<dyn Stream<Item = std::result::Result<DeliveredBatch, TriggerError>> + Send>>,
506 TriggerError,
507 > {
508 let streaming = self
509 .trigger_client()
510 .subscribe(SubscribeRequest {
511 topic: Some(TopicName {
512 name: topic.name.clone(),
513 }),
514 predicate: predicate.source_sql().unwrap_or("").to_string(),
518 from_offset: from_offset.map(|o| o.value()),
519 tenant_id: String::new(),
520 replay_only,
521 })
522 .await
523 .map_err(|s| trigger_error_from_status(&s))?
524 .into_inner();
525 let mapped = streaming.map(|item| match item {
530 Ok(wire) => decode_subscribed_batch(wire).map_err(|s| trigger_error_from_status(&s)),
531 Err(status) => Err(trigger_error_from_status(&status)),
532 });
533 Ok(Box::pin(mapped))
534 }
535
536 pub async fn audit_log(
541 &self,
542 records: Vec<PerQueryAudit>,
543 ) -> std::result::Result<(), AuditError> {
544 self.audit_client()
545 .audit_log(AuditLogRequest {
546 records: records.into_iter().map(Into::into).collect(),
547 })
548 .await
549 .map_err(|s| audit_error_from_status(&s))?;
550 Ok(())
551 }
552
553 pub async fn audit_fetch_by_query_id(
555 &self,
556 query_id: uuid::Uuid,
557 ) -> std::result::Result<Option<PerQueryAudit>, AuditError> {
558 let resp = self
559 .audit_client()
560 .audit_fetch_by_query_id(AuditFetchByQueryIdRequest {
561 query_id: query_id.to_string(),
562 })
563 .await
564 .map_err(|s| audit_error_from_status(&s))?
565 .into_inner();
566 resp.record.map(record_from_wire).transpose()
567 }
568
569 pub async fn audit_fetch_recent(
571 &self,
572 limit: usize,
573 ) -> std::result::Result<Vec<PerQueryAudit>, AuditError> {
574 let resp = self
575 .audit_client()
576 .audit_fetch_recent(AuditFetchRecentRequest {
577 limit: limit as u32,
578 })
579 .await
580 .map_err(|s| audit_error_from_status(&s))?
581 .into_inner();
582 resp.records.into_iter().map(record_from_wire).collect()
583 }
584}
585
586fn proto_cache_policy(cache: CachePolicy) -> jammi_wire::proto::inference::CachePolicy {
588 use jammi_wire::proto::inference::CachePolicy as Pb;
589 match cache {
590 CachePolicy::Use => Pb::Use,
591 CachePolicy::Bypass => Pb::Bypass,
592 }
593}
594
595fn cache_outcome_from_proto(
611 outcome: i32,
612 reused_table: &str,
613) -> std::result::Result<CacheOutcome, tonic::Status> {
614 use jammi_wire::proto::inference::CacheOutcome as Pb;
615 match Pb::try_from(outcome) {
616 Ok(Pb::Computed) => Ok(CacheOutcome::Computed),
617 Ok(Pb::Reused) => Ok(CacheOutcome::Reused {
618 table: reused_table.to_string(),
619 }),
620 Ok(Pb::Unspecified) | Err(_) => Err(tonic::Status::internal(
621 "producer returned an unspecified cache outcome",
622 )),
623 }
624}
625
626fn proto_modality(modality: Modality) -> jammi_wire::proto::embedding::Modality {
630 use jammi_wire::proto::embedding::Modality as Pb;
631 match modality {
632 Modality::Text => Pb::Text,
633 Modality::Image => Pb::Image,
634 Modality::Audio => Pb::Audio,
635 }
636}
637
638fn hits_to_batch(resp: SearchResponse, select: &[String]) -> Result<Vec<RecordBatch>> {
646 if resp.hits.is_empty() {
647 return Ok(Vec::new());
648 }
649 let keys: Vec<&str> = resp.hits.iter().map(|h| h.key.as_str()).collect();
650 let scores: Vec<f32> = resp.hits.iter().map(|h| h.score).collect();
651
652 let mut fields: Vec<Field> = vec![
653 Field::new("_row_id", DataType::Utf8, false),
654 Field::new("similarity", DataType::Float32, false),
655 ];
656 let mut arrays: Vec<ArrayRef> = vec![
657 Arc::new(StringArray::from(keys)),
658 Arc::new(Float32Array::from(scores)),
659 ];
660 for name in select {
661 let values: Vec<String> = resp
662 .hits
663 .iter()
664 .map(|h| h.columns.get(name).cloned().unwrap_or_default())
665 .collect();
666 fields.push(Field::new(name, DataType::Utf8, false));
667 arrays.push(Arc::new(StringArray::from(values)));
668 }
669
670 let schema = Arc::new(Schema::new(fields));
671 let batch = RecordBatch::try_new(schema, arrays)
672 .map_err(|e| JammiError::Other(format!("rebuild search batch: {e}")))?;
673 Ok(vec![batch])
674}