tritonserver_rs/
parameter.rs

1use std::{fs::File, path::Path, ptr::null_mut};
2
3use crate::{error::Error, sys, to_cstring};
4
5/// Types of parameters recognized by TRITONSERVER.
6#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
7#[repr(u32)]
8pub enum TritonParameterType {
9    String = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_STRING,
10    Int = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_INT,
11    Bool = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_BOOL,
12    Double = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_DOUBLE,
13    Bytes = sys::TRITONSERVER_parametertype_enum_TRITONSERVER_PARAMETER_BYTES,
14}
15
/// Typed value stored in a `Parameter`.
///
/// Each variant lines up with the `TritonParameterType` variant of the same name.
#[derive(Clone, Debug)]
pub enum ParameterContent {
    /// Text; converted to a C string when the native parameter is created.
    String(String),
    /// Signed 64-bit integer.
    Int(i64),
    /// Boolean flag.
    Bool(bool),
    /// 64-bit floating-point number.
    Double(f64),
    /// Arbitrary byte buffer.
    Bytes(Vec<u8>),
}
25
26/// Parameter of the [Server](crate::Server) or [Response](crate::Response).
27#[derive(Debug)]
28pub struct Parameter {
29    pub(crate) ptr: *mut sys::TRITONSERVER_Parameter,
30    pub name: String,
31    pub content: ParameterContent,
32}
33
34unsafe impl Send for Parameter {}
35
36impl Parameter {
37    /// Create new Parameter.
38    pub fn new<N: AsRef<str>>(name: N, value: ParameterContent) -> Result<Self, Error> {
39        let c_name = to_cstring(&name)?;
40        let ptr = match &value {
41            ParameterContent::Bool(v) => unsafe {
42                sys::TRITONSERVER_ParameterNew(
43                    c_name.as_ptr(),
44                    TritonParameterType::Bool as _,
45                    v as *const bool as *const _,
46                )
47            },
48            ParameterContent::Int(v) => unsafe {
49                sys::TRITONSERVER_ParameterNew(
50                    c_name.as_ptr(),
51                    TritonParameterType::Int as _,
52                    v as *const i64 as *const _,
53                )
54            },
55            ParameterContent::String(v) => {
56                let v = to_cstring(v)?;
57                unsafe {
58                    sys::TRITONSERVER_ParameterNew(
59                        c_name.as_ptr(),
60                        TritonParameterType::String as _,
61                        v.as_ptr() as *const _,
62                    )
63                }
64            }
65            ParameterContent::Double(v) => unsafe {
66                sys::TRITONSERVER_ParameterNew(
67                    c_name.as_ptr(),
68                    TritonParameterType::Double as _,
69                    v as *const f64 as *const _,
70                )
71            },
72            ParameterContent::Bytes(v) => unsafe {
73                sys::TRITONSERVER_ParameterBytesNew(
74                    c_name.as_ptr(),
75                    v.as_ptr() as *const _,
76                    v.len() as _,
77                )
78            },
79        };
80
81        Ok(Self {
82            ptr,
83            name: name.as_ref().to_string(),
84            content: value,
85        })
86    }
87
88    /// Create String Parameter of model config with exact version of the model. \
89    /// `config`: model config.pbtxt as json value.
90    /// Check [load_config_as_json] to permutate .pbtxt config to json value. \
91    /// If [Options::model_control_mode](crate::options::Options::model_control_mode) set as EXPLICIT and the result of this method is passed to [crate::Server::load_model_with_parametrs],
92    /// the server will load only that exact model and only that exact version of it.
93    pub fn from_config_with_exact_version(
94        mut config: serde_json::Value,
95        version: i64,
96    ) -> Result<Self, Error> {
97        config["version_policy"] = serde_json::json!({"specific": { "versions": [version]}});
98        Parameter::new("config", ParameterContent::String(config.to_string()))
99    }
100}
101
102impl Clone for Parameter {
103    fn clone(&self) -> Self {
104        Parameter::new(self.name.clone(), self.content.clone()).unwrap_or_else(|err| {
105            log::warn!("Error cloning parameter: {err}. Result will be empty, do not use it.");
106            Parameter {
107                ptr: null_mut(),
108                name: String::new(),
109                content: ParameterContent::String(String::new()),
110            }
111        })
112    }
113}
114
115impl Drop for Parameter {
116    fn drop(&mut self) {
117        if !self.ptr.is_null() {
118            unsafe { sys::TRITONSERVER_ParameterDelete(self.ptr) }
119        }
120    }
121}
122
123fn hjson_to_json(value: serde_hjson::Value) -> serde_json::Value {
124    match value {
125        serde_hjson::Value::Null => serde_json::Value::Null,
126        serde_hjson::Value::U64(v) => serde_json::Value::from(v),
127        serde_hjson::Value::I64(v) => serde_json::Value::from(v),
128        serde_hjson::Value::F64(v) => serde_json::Value::from(v),
129        serde_hjson::Value::Bool(v) => serde_json::Value::from(v),
130        serde_hjson::Value::String(v) => serde_json::Value::from(v),
131
132        serde_hjson::Value::Array(v) => {
133            serde_json::Value::from_iter(v.into_iter().map(hjson_to_json))
134        }
135        serde_hjson::Value::Object(v) => serde_json::Value::from_iter(
136            v.into_iter()
137                .map(|(key, value)| (key, hjson_to_json(value))),
138        ),
139    }
140}
141
142/// Load config.pbtxt from the `config_path` and parse it to json value. \
143/// Might be useful if it is required to run model with altered config.
144/// In this case String [Parameter] with name 'config' and the result of this method as data should be created
145/// and passed to [Server::load_model_with_parametrs](crate::Server::load_model_with_parametrs) ([Options::model_control_mode](crate::options::Options::model_control_mode) set as EXPLICIT required).
146/// Check realization of [Parameter::from_config_with_exact_version] as an example. \
147/// **Note (Subject to change)**: congig must be in [hjson format](https://hjson.github.io/).
148pub fn load_config_as_json<P: AsRef<Path>>(config_path: P) -> Result<serde_json::Value, Error> {
149    let content = File::open(config_path).map_err(|err| {
150        Error::new(
151            crate::error::ErrorCode::InvalidArg,
152            format!("Error opening the config file: {err}"),
153        )
154    })?;
155    let value = serde_hjson::from_reader::<_, serde_hjson::Value>(&content).map_err(|err| {
156        Error::new(
157            crate::error::ErrorCode::InvalidArg,
158            format!("Error parsing the config file as hjson: {err}"),
159        )
160    })?;
161    Ok(hjson_to_json(value))
162}
163
/// Checks that the bundled voicenet config.pbtxt parses into the expected JSON structure.
#[test]
fn test_config_to_json() {
    let loaded = load_config_as_json("model_repo/voicenet_onnx/voicenet/config.pbtxt").unwrap();

    let expected = serde_json::json!({
        "name": "voicenet",
        "platform": "onnxruntime_onnx",
        "input": [
            {
                "data_type": "TYPE_FP32",
                "name": "input",
                "dims": [512, 160000]
            }
        ],
        "output": [
            {
                "data_type": "TYPE_FP32",
                "name": "output",
                "dims": [512, 512]
            }
        ],
        "instance_group": [
            {
                "count": 2,
                "kind": "KIND_CPU"
            }
        ],
        "optimization": {
            "execution_accelerators": {
                "cpu_execution_accelerator": [
                    { "name": "openvino" }
                ]
            }
        }
    });

    assert_eq!(loaded, expected);
}