#![warn(missing_docs)]
pub mod error;
#[allow(clippy::all, missing_docs)]
pub mod inference;
pub mod model;
pub mod registry;
pub mod server;
pub mod system;
#[cfg(test)]
mod tests;
use std::fmt::Debug;
use crate::inference::grpc_inference_service_client::GrpcInferenceServiceClient;
use crate::inference::model_infer_request::{InferInputTensor, InferRequestedOutputTensor};
use crate::inference::{InferParameter, ModelInferRequest, ModelInferResponse};
use std::collections::HashMap;
use tonic::transport::Channel;
use tonic::Status;
use tracing::{event, instrument, Level};
/// Async gRPC client for a Triton Inference Server.
///
/// Construct it with [`TritonClient::new`], then send requests with
/// [`TritonClient::submit_inference_request`].
pub struct TritonClient {
    /// Generated gRPC client. Each request works on a clone of it (see [`TritonClient::client`]).
    client: GrpcInferenceServiceClient<Channel>,
    /// Address passed to `new`. It is kept so `Debug` output can identify the client.
    addr: String,
}
/// A model served by Triton that can build its own inference requests and
/// present the corresponding responses.
pub trait TritonModel {
    /// Builds a [`ModelInferRequest`] from raw input byte buffers.
    ///
    /// How `raw_inputs` maps onto the model's input tensors is defined by the
    /// implementor. Presumably there is one buffer per input tensor, but
    /// verify this against each implementation.
    fn build_inference_request(&self, raw_inputs: Vec<Vec<u8>>) -> ModelInferRequest;

    /// Presents the outputs contained in `response`.
    ///
    /// The output destination and format are implementation-defined.
    fn display_output(&self, response: &ModelInferResponse);
}
/// Common description of a Triton model: its identity, its request
/// parameters, and its input and output tensor specifications.
#[derive(Clone, Debug)]
pub struct TritonModelBase {
    /// Model name as known to the Triton server.
    pub name: String,
    /// Model version string.
    pub version: String,
    /// Inference parameters keyed by parameter name.
    pub param: HashMap<String, InferParameter>,
    /// Input tensor descriptors for requests to this model.
    pub inputs: Vec<InferInputTensor>,
    /// Output tensors to request from this model.
    pub outputs: Vec<InferRequestedOutputTensor>,
}
impl TritonModelBase {
    /// Creates a model description from its parts.
    ///
    /// Every argument is moved into the field of the same name; no validation
    /// is performed. Note that the argument order (`inputs`, `outputs`,
    /// `param`) differs from the struct's field order.
    pub fn new(
        name: String,
        version: String,
        inputs: Vec<InferInputTensor>,
        outputs: Vec<InferRequestedOutputTensor>,
        param: HashMap<String, InferParameter>,
    ) -> Self {
        Self {
            name,
            version,
            param,
            inputs,
            outputs,
        }
    }
}
impl Debug for TritonClient {
    /// Renders the client as `TritonClient { addr: .. }`. Only the address is
    /// printed; the gRPC client handle is left out of the output.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("TritonClient");
        let builder = builder.field("addr", &self.addr);
        builder.finish()
    }
}
impl TritonClient {
    /// Delay between consecutive connection attempts in [`TritonClient::new`].
    const RECONNECT_DELAY: core::time::Duration = core::time::Duration::from_secs(3);

    /// Connects to the Triton gRPC endpoint at `addr`.
    ///
    /// If a connection attempt fails, the error is logged at `WARN` level and
    /// the attempt is retried after [`Self::RECONNECT_DELAY`]. Retries are
    /// unbounded, so the returned future only completes once a connection
    /// succeeds. It never resolves if the server is never reachable.
    ///
    /// `addr` is passed to `GrpcInferenceServiceClient::connect`. Presumably
    /// it is a URI such as `http://host:port`; confirm this with callers.
    #[instrument]
    pub async fn new(addr: &str) -> TritonClient {
        event!(Level::TRACE, addr);
        let client = loop {
            match Self::build_client(addr.to_owned()).await {
                Ok(client) => break client,
                Err(e) => {
                    event!(Level::WARN, error = ?e);
                    tokio::time::sleep(Self::RECONNECT_DELAY).await;
                }
            }
        };
        event!(Level::INFO, "Connected to Triton client at {}", addr);
        Self {
            client,
            addr: addr.to_owned(),
        }
    }

    /// Makes a single connection attempt to `addr`.
    async fn build_client(
        addr: String,
    ) -> Result<GrpcInferenceServiceClient<Channel>, tonic::transport::Error> {
        GrpcInferenceServiceClient::connect(addr).await
    }

    /// Returns a clone of the underlying gRPC client.
    ///
    /// A fresh handle is useful because the generated RPC methods take
    /// `&mut self`, while this type only hands out `&self`.
    pub fn client(&self) -> GrpcInferenceServiceClient<Channel> {
        self.client.clone()
    }

    /// Sends `inference_request` through the `ModelInfer` RPC and returns the
    /// decoded response body.
    ///
    /// # Errors
    ///
    /// Returns the [`Status`] reported by the RPC call if it fails.
    #[instrument(skip(inference_request))]
    pub async fn submit_inference_request(
        &self,
        inference_request: ModelInferRequest,
    ) -> Result<ModelInferResponse, Status> {
        event!(Level::TRACE, model_name = %inference_request.model_name,
            model_version = %inference_request.model_version,
            request_id = %inference_request.id,
        );
        Ok(self
            .client()
            .model_infer(inference_request)
            .await?
            .into_inner())
    }

    /// Lists the names of the models available on the server.
    ///
    /// # Panics
    ///
    /// Always panics: this method is not implemented yet.
    pub async fn model_names(&self) -> Vec<String> {
        unimplemented!();
    }
}