use super::context::{LoadContext, SaveContext};
/// A bag of optional model settings that can be loaded from and saved to
/// JSON or YAML (see `load_from_value` and `to_value`).
///
/// Every typed field is an `Option`: `load_from_value` leaves a field as
/// `None` when its key is missing or holds a value of an unexpected JSON
/// type, and `to_value` omits the key entirely for `None` fields.
///
/// NOTE(review): the meaning and valid range of each setting is presumably
/// defined by the model provider these options are forwarded to — confirm
/// against the consumer of this type.
#[derive(Debug, Clone, Default)]
pub struct ModelOptions {
/// Stored under the `frequencyPenalty` key.
pub frequency_penalty: Option<f32>,
/// Stored under the `maxOutputTokens` key.
pub max_output_tokens: Option<i32>,
/// Stored under the `presencePenalty` key.
pub presence_penalty: Option<f32>,
/// Stored under the `seed` key.
pub seed: Option<i32>,
/// Stored under the `temperature` key.
pub temperature: Option<f32>,
/// Stored under the `topK` key.
pub top_k: Option<i32>,
/// Stored under the `topP` key.
pub top_p: Option<f32>,
/// Stored under the `stopSequences` key as an array of strings.
pub stop_sequences: Option<Vec<String>>,
/// Stored under the `allowMultipleToolCalls` key.
pub allow_multiple_tool_calls: Option<bool>,
/// Free-form JSON stored under the `additionalProperties` key.
/// `Value::Null` means "absent": it is what the loader uses when the key
/// is missing, and the saver skips the key when the value is null.
pub additional_properties: serde_json::Value,
}
impl ModelOptions {
    /// Creates a `ModelOptions` with every field set to its `Default` value.
    pub fn new() -> Self {
        Self::default()
    }

    /// Parses `json` and builds a `ModelOptions` from the resulting value.
    ///
    /// # Errors
    ///
    /// Returns the underlying `serde_json::Error` when `json` cannot be
    /// parsed. Keys with unexpected types do not cause an error; see
    /// [`ModelOptions::load_from_value`].
    pub fn from_json(json: &str, ctx: &LoadContext) -> Result<Self, serde_json::Error> {
        let value: serde_json::Value = serde_json::from_str(json)?;
        Ok(Self::load_from_value(&value, ctx))
    }

    /// Parses `yaml` into a JSON value and builds a `ModelOptions` from it.
    ///
    /// # Errors
    ///
    /// Returns the underlying `serde_yaml::Error` when `yaml` cannot be
    /// deserialized into a `serde_json::Value`. Keys with unexpected types do
    /// not cause an error; see [`ModelOptions::load_from_value`].
    pub fn from_yaml(yaml: &str, ctx: &LoadContext) -> Result<Self, serde_yaml::Error> {
        let value: serde_json::Value = serde_yaml::from_str(yaml)?;
        Ok(Self::load_from_value(&value, ctx))
    }

    /// Builds a `ModelOptions` from an already-parsed JSON value.
    ///
    /// A clone of `value` is first passed through `ctx.process_input`; the
    /// fields are then read from the result using camelCase keys
    /// (`frequencyPenalty`, `maxOutputTokens`, ...).
    ///
    /// Loading is lenient and never fails:
    /// - a missing key, or a key whose value has the wrong JSON type, yields
    ///   `None` for that field;
    /// - an integer that does not fit in `i32` yields `None` instead of being
    ///   silently wrapped to an unrelated number;
    /// - non-string entries inside `stopSequences` are skipped;
    /// - a missing `additionalProperties` becomes `Value::Null`.
    pub fn load_from_value(value: &serde_json::Value, ctx: &LoadContext) -> Self {
        let value = ctx.process_input(value.clone());

        // The struct stores `f32`, so the parsed `f64` is narrowed on purpose.
        let float = |key: &str| -> Option<f32> {
            value.get(key).and_then(|v| v.as_f64()).map(|n| n as f32)
        };
        // Checked narrowing: an `as` cast would wrap out-of-range input
        // (e.g. 4294967297 -> 1). The trait path is fully qualified so this
        // compiles whether or not `TryFrom` is in the edition's prelude.
        let int = |key: &str| -> Option<i32> {
            value
                .get(key)
                .and_then(|v| v.as_i64())
                .and_then(|n| <i32 as std::convert::TryFrom<i64>>::try_from(n).ok())
        };

        Self {
            frequency_penalty: float("frequencyPenalty"),
            max_output_tokens: int("maxOutputTokens"),
            presence_penalty: float("presencePenalty"),
            seed: int("seed"),
            temperature: float("temperature"),
            top_k: int("topK"),
            top_p: float("topP"),
            stop_sequences: value
                .get("stopSequences")
                .and_then(|v| v.as_array())
                .map(|entries| {
                    entries
                        .iter()
                        .filter_map(|entry| entry.as_str().map(str::to_owned))
                        .collect::<Vec<String>>()
                }),
            allow_multiple_tool_calls: value
                .get("allowMultipleToolCalls")
                .and_then(|v| v.as_bool()),
            additional_properties: value
                .get("additionalProperties")
                .cloned()
                .unwrap_or(serde_json::Value::Null),
        }
    }

    /// Serializes the options into a JSON object.
    ///
    /// Only fields that are `Some` (and a non-null `additional_properties`)
    /// are written, using the same camelCase keys that
    /// [`ModelOptions::load_from_value`] reads. The assembled object is passed
    /// through `ctx.process_dict`, and that call's result is returned.
    pub fn to_value(&self, ctx: &SaveContext) -> serde_json::Value {
        let mut result = serde_json::Map::new();
        let mut put = |key: &str, val: serde_json::Value| {
            result.insert(key.to_string(), val);
        };

        // `Value::from` is infallible for these primitive types; a non-finite
        // float becomes `Value::Null`.
        if let Some(val) = self.frequency_penalty {
            put("frequencyPenalty", serde_json::Value::from(val));
        }
        if let Some(val) = self.max_output_tokens {
            put("maxOutputTokens", serde_json::Value::from(val));
        }
        if let Some(val) = self.presence_penalty {
            put("presencePenalty", serde_json::Value::from(val));
        }
        if let Some(val) = self.seed {
            put("seed", serde_json::Value::from(val));
        }
        if let Some(val) = self.temperature {
            put("temperature", serde_json::Value::from(val));
        }
        if let Some(val) = self.top_k {
            put("topK", serde_json::Value::from(val));
        }
        if let Some(val) = self.top_p {
            put("topP", serde_json::Value::from(val));
        }
        if let Some(items) = &self.stop_sequences {
            put("stopSequences", serde_json::Value::from(items.clone()));
        }
        if let Some(val) = self.allow_multiple_tool_calls {
            put("allowMultipleToolCalls", serde_json::Value::Bool(val));
        }
        if !self.additional_properties.is_null() {
            put("additionalProperties", self.additional_properties.clone());
        }

        ctx.process_dict(serde_json::Value::Object(result))
    }

    /// Serializes the options to a pretty-printed JSON string.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json::Error` produced while rendering the value
    /// returned by [`ModelOptions::to_value`].
    pub fn to_json(&self, ctx: &SaveContext) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(&self.to_value(ctx))
    }

    /// Serializes the options to a YAML string.
    ///
    /// # Errors
    ///
    /// Returns the `serde_yaml::Error` produced while rendering the value
    /// returned by [`ModelOptions::to_value`].
    pub fn to_yaml(&self, ctx: &SaveContext) -> Result<String, serde_yaml::Error> {
        serde_yaml::to_string(&self.to_value(ctx))
    }

    /// Borrows `additional_properties` as a JSON object map.
    ///
    /// Returns `None` when `additional_properties` is not a JSON object
    /// (for example when it is `Value::Null`).
    pub fn as_additional_properties_dict(
        &self,
    ) -> Option<&serde_json::Map<String, serde_json::Value>> {
        self.additional_properties.as_object()
    }
}