tritonserver_rs/
parameter.rs1use std::{fs::File, path::Path, ptr::null_mut};
2
3use crate::{error::Error, sys, to_cstring};
4
5#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
7#[repr(u32)]
8pub enum TritonParameterType {
9 String = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_STRING,
10 Int = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_INT,
11 Bool = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_BOOL,
12 Double = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_DOUBLE,
13 Bytes = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_BYTES,
14}
15
/// Value carried by a [`Parameter`], one variant per supported Triton type.
///
/// `PartialEq` is derived so that contents (and therefore parameters) can be
/// compared; `Eq` is not possible because of the `f64` payload.
#[derive(Debug, Clone, PartialEq)]
pub enum ParameterContent {
    /// UTF-8 string; converted to a C string when the parameter is created.
    String(String),
    /// 64-bit signed integer.
    Int(i64),
    /// Boolean flag.
    Bool(bool),
    /// 64-bit float.
    Double(f64),
    /// Raw byte buffer.
    Bytes(Vec<u8>),
}
25
26#[derive(Debug)]
28pub struct Parameter {
29 pub(crate) ptr: *mut sys::TRITONSERVER_Parameter,
30 pub name: String,
31 pub content: ParameterContent,
32}
33
34unsafe impl Send for Parameter {}
35
36impl Parameter {
37 pub fn new<N: AsRef<str>>(name: N, value: ParameterContent) -> Result<Self, Error> {
39 let c_name = to_cstring(&name)?;
40 let ptr = match &value {
41 ParameterContent::Bool(v) => unsafe {
42 sys::TRITONSERVER_ParameterNew(
43 c_name.as_ptr(),
44 TritonParameterType::Bool as _,
45 v as *const bool as *const _,
46 )
47 },
48 ParameterContent::Int(v) => unsafe {
49 sys::TRITONSERVER_ParameterNew(
50 c_name.as_ptr(),
51 TritonParameterType::Int as _,
52 v as *const i64 as *const _,
53 )
54 },
55 ParameterContent::String(v) => {
56 let v = to_cstring(v)?;
57 unsafe {
58 sys::TRITONSERVER_ParameterNew(
59 c_name.as_ptr(),
60 TritonParameterType::String as _,
61 v.as_ptr() as *const _,
62 )
63 }
64 }
65 ParameterContent::Double(v) => unsafe {
66 sys::TRITONSERVER_ParameterNew(
67 c_name.as_ptr(),
68 TritonParameterType::Double as _,
69 v as *const f64 as *const _,
70 )
71 },
72 ParameterContent::Bytes(v) => unsafe {
73 sys::TRITONSERVER_ParameterBytesNew(
74 c_name.as_ptr(),
75 v.as_ptr() as *const _,
76 v.len() as _,
77 )
78 },
79 };
80
81 Ok(Self {
82 ptr,
83 name: name.as_ref().to_string(),
84 content: value,
85 })
86 }
87
88 pub fn from_config_with_exact_version(
94 mut config: serde_json::Value,
95 version: i64,
96 ) -> Result<Self, Error> {
97 config["version_policy"] = serde_json::json!({"specific": { "versions": [version]}});
98 Parameter::new("config", ParameterContent::String(config.to_string()))
99 }
100}
101
102impl Clone for Parameter {
103 fn clone(&self) -> Self {
104 Parameter::new(self.name.clone(), self.content.clone()).unwrap_or_else(|err| {
105 log::warn!("Error cloning parameter: {err}. Result will be empty, do not use it.");
106 Parameter {
107 ptr: null_mut(),
108 name: String::new(),
109 content: ParameterContent::String(String::new()),
110 }
111 })
112 }
113}
114
115impl Drop for Parameter {
116 fn drop(&mut self) {
117 if !self.ptr.is_null() {
118 unsafe { sys::TRITONSERVER_ParameterDelete(self.ptr) }
119 }
120 }
121}
122
123fn hjson_to_json(value: serde_hjson::Value) -> serde_json::Value {
124 match value {
125 serde_hjson::Value::Null => serde_json::Value::Null,
126 serde_hjson::Value::U64(v) => serde_json::Value::from(v),
127 serde_hjson::Value::I64(v) => serde_json::Value::from(v),
128 serde_hjson::Value::F64(v) => serde_json::Value::from(v),
129 serde_hjson::Value::Bool(v) => serde_json::Value::from(v),
130 serde_hjson::Value::String(v) => serde_json::Value::from(v),
131
132 serde_hjson::Value::Array(v) => {
133 serde_json::Value::from_iter(v.into_iter().map(hjson_to_json))
134 }
135 serde_hjson::Value::Object(v) => serde_json::Value::from_iter(
136 v.into_iter()
137 .map(|(key, value)| (key, hjson_to_json(value))),
138 ),
139 }
140}
141
142pub fn load_config_as_json<P: AsRef<Path>>(config_path: P) -> Result<serde_json::Value, Error> {
149 let content = File::open(config_path).map_err(|err| {
150 Error::new(
151 crate::error::ErrorCode::InvalidArg,
152 format!("Error opening the config file: {err}"),
153 )
154 })?;
155 let value = serde_hjson::from_reader::<_, serde_hjson::Value>(&content).map_err(|err| {
156 Error::new(
157 crate::error::ErrorCode::InvalidArg,
158 format!("Error parsing the config file as hjson: {err}"),
159 )
160 })?;
161 Ok(hjson_to_json(value))
162}
163
/// Checks that the on-disk fixture config loads into the expected JSON tree.
#[test]
fn test_config_to_json() {
    let expected = serde_json::json!({
        "name": "voicenet",
        "platform": "onnxruntime_onnx",
        "input": [
            {
                "data_type": "TYPE_FP32",
                "name": "input",
                "dims": [512, 160000]
            }
        ],
        "output": [
            {
                "data_type": "TYPE_FP32",
                "name": "output",
                "dims": [512, 512]
            }
        ],
        "instance_group": [
            {
                "count": 2,
                "kind": "KIND_CPU"
            }
        ],
        "optimization": {
            "execution_accelerators": {
                "cpu_execution_accelerator": [
                    { "name": "openvino" }
                ]
            }
        }
    });

    // Relative to the crate root, where `cargo test` runs.
    let fixture = "model_repo/voicenet_onnx/voicenet/config.pbtxt";
    let loaded = load_config_as_json(fixture).unwrap();
    assert_eq!(loaded, expected);
}