use std::collections::HashMap;
use tonic::{Request, Status};
use tracing::{event, instrument, Level};
use crate::inference::ModelRepositoryParameter;
use crate::inference::{RepositoryIndexRequest, RepositoryIndexResponse};
use crate::inference::{RepositoryModelLoadRequest, RepositoryModelLoadResponse};
use crate::inference::{RepositoryModelUnloadRequest, RepositoryModelUnloadResponse};
use crate::TritonClient;
impl TritonClient {
#[instrument]
pub async fn model_repository_index(&self) -> Result<RepositoryIndexResponse, Status> {
let message = RepositoryIndexRequest {
repository_name: "".to_owned(),
ready: true,
};
let request = Request::new(message);
match self.client().repository_index(request).await {
Ok(r) => {
event!(Level::INFO, model_repository_index = ?r);
Ok(r.into_inner())
}
Err(s) => {
event!(Level::ERROR, status = ?s);
Err(s)
}
}
}
#[instrument]
pub async fn load_model(
&self,
repository_name: Option<&str>,
model_name: &str,
parameters: Option<HashMap<String, ModelRepositoryParameter>>,
) -> Result<RepositoryModelLoadResponse, Status> {
let repository_name = repository_name.unwrap_or("").to_owned();
let model_name = model_name.to_owned();
let parameters = match parameters {
Some(p) => p,
None => HashMap::new(),
};
let request = RepositoryModelLoadRequest {
repository_name,
model_name,
parameters,
};
let request = Request::new(request);
match self.client().repository_model_load(request).await {
Ok(r) => {
event!(Level::INFO, "successfully loaded model {r:?}");
Ok(r.into_inner())
}
Err(s) => {
event!(Level::ERROR, "failed to load model {s}");
Err(s)
}
}
}
#[instrument]
pub async fn unload_model(
&self,
repository_name: Option<&str>,
model_name: &str,
parameters: Option<HashMap<String, ModelRepositoryParameter>>,
) -> Result<RepositoryModelUnloadResponse, Status> {
let repository_name = repository_name.unwrap_or("").to_owned();
let model_name = model_name.to_owned();
let parameters = match parameters {
Some(p) => p,
None => HashMap::new(),
};
let request = RepositoryModelUnloadRequest {
repository_name,
model_name,
parameters,
};
let request = Request::new(request);
match self.client().repository_model_unload(request).await {
Ok(r) => {
event!(Level::INFO, "successfully unloaded model, {r:?}");
Ok(r.into_inner())
}
Err(s) => {
event!(Level::ERROR, "failed to unload model, {s}");
Err(s)
}
}
}
}