use std::{collections::HashSet, str::FromStr};
use anyhow::{Context, Result};
use serde::Deserialize;
use crate::{
runtime::ValorRuntimeEngine,
service::{
ValorCommonCmdService, ValorCommonService, ValorRighDeviceClassificationService,
ValorRighService, ValorService, ValorServiceId,
},
};
/// Registry of the services a worker exposes, built from the service TOML.
///
/// `service_ids` holds every enabled id from the config, including ids for
/// which no instance could be (or needed to be) built; `services` holds only
/// the instances that were actually constructed.
#[derive(Debug, Default, Clone)]
pub struct ValorWorkerServiceRegistry {
    /// Ids of all enabled services listed in the config.
    service_ids: HashSet<ValorServiceId>,
    /// Service instances that were successfully built.
    // `dead_code` (rather than the whole `unused` lint group) is the lint a
    // never-read field triggers; the field may only be read through the
    // `services()` accessor, which itself may have no callers yet.
    #[allow(dead_code)]
    services: Vec<ValorService>,
}
impl ValorWorkerServiceRegistry {
    /// Loads the worker's service registry from the service TOML file.
    ///
    /// The path comes from the `VALOR_SERVICE_TOML` environment variable and
    /// falls back to `valor/service.toml` when that variable is unset or not
    /// valid Unicode. Every entry whose `enabled` flag is absent or `true` is
    /// registered by id; an instance is additionally built for ids the worker
    /// knows how to instantiate.
    ///
    /// A failure to build an individual instance is logged and does not fail
    /// the load (the id stays registered). A repeated id is logged and only its
    /// first enabled entry is used.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or does not deserialize into
    /// the expected config schema.
    pub fn load_default() -> Result<Self> {
        let path = std::env::var("VALOR_SERVICE_TOML")
            .unwrap_or_else(|_| "valor/service.toml".to_string());
        tracing::info!("Worker: loading services from {}", path);
        let data = std::fs::read_to_string(&path)
            .with_context(|| format!("Failed to read service config at {path}"))?;
        let cfg: ValorServiceToml = toml::from_str(&data)
            .with_context(|| format!("Failed to parse service config at {path}"))?;

        let mut service_ids: HashSet<ValorServiceId> = HashSet::new();
        let mut services: Vec<ValorService> = Vec::new();
        for entry in cfg
            .service
            .into_iter()
            .filter(|e| e.enabled.unwrap_or(true))
        {
            // The registry is keyed by id: a repeated id would otherwise add a
            // second instance that `contains`/`service_ids` cannot tell apart.
            if !service_ids.insert(entry.id.clone()) {
                tracing::warn!(
                    "Worker: duplicate service '{}' in config (ignoring repeated entry)",
                    entry.id
                );
                continue;
            }
            match build_service_by_id(&entry.id, &entry) {
                Ok(Some(svc)) => services.push(svc),
                Ok(None) => {
                    tracing::info!(
                        "Worker: service '{}' registered without instance (capability only)",
                        entry.id
                    );
                }
                Err(e) => {
                    tracing::warn!(
                        "Worker: failed to build service '{}' (skipping instance): {}",
                        entry.id,
                        e
                    );
                }
            }
        }
        Ok(Self {
            service_ids,
            services,
        })
    }

    /// Returns the service instances that were successfully built.
    // `dead_code` is the specific lint an uncalled method triggers; this
    // accessor may currently have no callers.
    #[allow(dead_code)]
    pub fn services(&self) -> &[ValorService] {
        &self.services
    }

    /// Returns the ids of all enabled services, including those registered
    /// without an instance.
    ///
    /// The order is unspecified: it follows `HashSet` iteration order.
    pub fn service_ids(&self) -> Vec<ValorServiceId> {
        self.service_ids.iter().cloned().collect()
    }

    /// Returns `true` if `id` was registered from the config.
    // `dead_code` is the specific lint an uncalled method triggers; this
    // lookup may currently have no callers.
    #[allow(dead_code)]
    pub fn contains(&self, id: &ValorServiceId) -> bool {
        self.service_ids.contains(id)
    }
}
/// Root of the service config file: a list of service entries under the
/// `service` key (an array of tables in TOML).
///
/// A file with no `service` key deserializes to an empty list.
#[derive(Deserialize, Debug)]
struct ValorServiceToml {
    /// All service entries, enabled or not, in file order.
    #[serde(default)]
    service: Vec<ValorServiceTomlEntry>,
}
/// A single service entry from the service config file.
///
/// Only `id` is required; every other key is optional and becomes `None`
/// when omitted.
#[derive(Clone, Debug, Deserialize)]
struct ValorServiceTomlEntry {
    /// Service identifier; selects which builder (if any) creates an instance.
    id: ValorServiceId,
    /// Whether the service is loaded; the loader treats a missing value as `true`.
    #[serde(default)]
    enabled: Option<bool>,
    /// Free-form name; for the common command service it is used as the command.
    #[serde(default)]
    name: Option<String>,
    /// Runtime engine name, parsed with `ValorRuntimeEngine::from_str`.
    #[serde(default)]
    runtime: Option<String>,
    /// Version string in `MAJOR.MINOR.PATCH` form.
    #[serde(default)]
    version: Option<String>,
    /// Model path handed to the service builder.
    #[serde(default)]
    model_path: Option<String>,
    /// Path to a file whose raw bytes are handed to the builder as metadata.
    #[serde(default)]
    metadata_file: Option<String>,
}
fn build_service_by_id(
id: &ValorServiceId,
entry: &ValorServiceTomlEntry,
) -> Result<Option<ValorService>> {
match id.as_ref() {
ValorServiceId::RIGH_DEVICE_CLASSIFICATION => {
let runtime = entry
.runtime
.as_deref()
.and_then(|s| ValorRuntimeEngine::from_str(s).ok())
.unwrap_or(ValorRuntimeEngine::TensorFlow);
let version = entry
.version
.as_deref()
.and_then(parse_version)
.unwrap_or_default();
let metadata: Option<Vec<u8>> = match &entry.metadata_file {
Some(p) => std::fs::read(p).ok(),
None => None,
};
let mut builder = ValorRighDeviceClassificationService::builder()
.runtime(runtime)
.version(version);
if let Some(bytes) = metadata.clone() {
builder = builder.metadata(bytes);
}
if let Some(path) = &entry.model_path {
builder = builder.model_path(path);
}
let svc = builder.build()?;
Ok(Some(ValorService::Righ(
ValorRighService::DeviceClassification(svc),
)))
}
ValorServiceId::COMMON_CMD => {
let runtime = entry
.runtime
.as_deref()
.and_then(|s| ValorRuntimeEngine::from_str(s).ok())
.unwrap_or(ValorRuntimeEngine::Cmd);
let version = entry
.version
.as_deref()
.and_then(parse_version)
.unwrap_or_else(|| righ_dm_rs::RighVersion::new(0, 1, 0));
let svc = ValorCommonCmdService::builder()
.runtime(runtime)
.version(version)
.command(entry.name.clone().unwrap_or_default())
.build()?;
Ok(Some(ValorService::Common(ValorCommonService::Cmd(svc))))
}
_ => Ok(None),
}
}
/// Parses a `MAJOR.MINOR.PATCH` version string such as `"1.2.3"`.
///
/// Surrounding whitespace is ignored. Returns `None` unless the string has
/// exactly three dot-separated components that each parse as `u32`. Extra
/// components (e.g. `"1.2.3.4"`) are rejected rather than silently truncated.
fn parse_version(s: &str) -> Option<righ_dm_rs::RighVersion> {
    let mut parts = s.trim().split('.');
    let major = parts.next()?.parse::<u32>().ok()?;
    let minor = parts.next()?.parse::<u32>().ok()?;
    let patch = parts.next()?.parse::<u32>().ok()?;
    if parts.next().is_some() {
        return None;
    }
    Some(righ_dm_rs::RighVersion::new(major, minor, patch))
}