use serde::{Deserialize, Serialize};
use twox_hash::XxHash3_128;
use schemars::JsonSchema;
// Configuration of a mock agent before an `id` has been attached to it.
//
// NOTE(review): plain `//` comments are used here instead of `///` on purpose.
// The `JsonSchema` derive is understood to copy doc comments into the generated
// schema as descriptions, so switching to `///` could change the emitted
// schema — confirm before converting.
//
// NOTE(review): `id()` hashes the serde_json serialization of this struct, so
// reordering or renaming fields presumably changes every derived id — verify
// before touching the layout.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[schemars(rename = "agent.mock.AgentBase")]
pub struct AgentBase {
// Upstream setting; the type is declared in the parent module.
pub upstream: super::Upstream,
// Output mode. `validate` requires `OutputMode::Instruction` whenever `mode`
// is `Mode::Invention`.
pub output_mode: super::OutputMode,
// Optional count; left out of the serialized form when `None`.
// `prepare` turns `0` and `1` into `None`; `validate` rejects values above 20.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
#[arbitrary(with = crate::arbitrary_util::arbitrary_option_u64)]
pub top_logprobs: Option<u64>,
// Optional error flag; defaults to `None` when absent and is left out of the
// serialized form when `None`. `prepare` turns `Some(false)` into `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub error: Option<bool>,
// Optional mode; `prepare` turns `Mode::Default` into `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub mode: Option<super::Mode>,
// Optional value that `validate` caps at 100 and only accepts when `error` is
// `Some(true)`. `prepare` clears it together with `error` when it is `0` and
// `error` is `Some(true)`. Presumably a percentage — confirm against the consumer.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub error_probability: Option<u8>,
// Optional MCP server configuration; normalized and checked through the
// `mcp::mcp_servers` `prepare` / `validate` helpers.
#[serde(skip_serializing_if = "Option::is_none")]
#[schemars(extend("omitempty" = true))]
pub mcp_servers: Option<super::super::McpServers>,
}
impl AgentBase {
    /// Normalizes the configuration in place so equivalent settings collapse
    /// to a single representation.
    ///
    /// - `top_logprobs` of `0` or `1` becomes `None`.
    /// - `error: Some(true)` paired with `error_probability: Some(0)` clears
    ///   both fields.
    /// - `error: Some(false)` becomes `None`.
    /// - `mode` equal to `Mode::Default` becomes `None`.
    /// - `mcp_servers`, when present, is replaced by whatever the
    ///   `mcp::mcp_servers::prepare` helper returns for it.
    pub fn prepare(&mut self) {
        if matches!(self.top_logprobs, Some(0 | 1)) {
            self.top_logprobs = None;
        }

        // The two `error` rules are mutually exclusive, so a single match on
        // the pair covers both.
        match (self.error, self.error_probability) {
            (Some(true), Some(0)) => {
                self.error = None;
                self.error_probability = None;
            }
            (Some(false), _) => self.error = None,
            _ => {}
        }

        if self
            .mode
            .as_ref()
            .is_some_and(|mode| *mode == super::Mode::Default)
        {
            self.mode = None;
        }

        self.mcp_servers = self
            .mcp_servers
            .take()
            .and_then(|servers| super::super::mcp::mcp_servers::prepare(servers));
    }

    /// Checks the configuration for invalid values and combinations.
    ///
    /// # Errors
    ///
    /// Returns the first failure found, checked in this order:
    /// 1. `top_logprobs` is greater than 20;
    /// 2. `mode` is `Mode::Invention` while `output_mode` is not
    ///    `OutputMode::Instruction`;
    /// 3. the `mcp::mcp_servers::validate` helper rejects `mcp_servers`;
    /// 4. `error_probability` is greater than 100;
    /// 5. `error_probability` is set while `error` is not `Some(true)`.
    pub fn validate(&self) -> Result<(), String> {
        if self.top_logprobs.is_some_and(|count| count > 20) {
            return Err(String::from("`top_logprobs` must be at most 20"));
        }

        let wants_invention = self.mode == Some(super::Mode::Invention);
        if wants_invention && self.output_mode != super::OutputMode::Instruction {
            return Err(String::from(
                "`mode: invention` is only compatible with `instruction` output mode",
            ));
        }

        if let Some(servers) = self.mcp_servers.as_ref() {
            super::super::mcp::mcp_servers::validate(servers)?;
        }

        match self.error_probability {
            Some(probability) if probability > 100 => {
                Err(String::from("`error_probability` must be at most 100"))
            }
            Some(_) if self.error != Some(true) => Err(String::from(
                "`error_probability` requires `error` to be true",
            )),
            _ => Ok(()),
        }
    }

    /// Returns `messages` unchanged.
    pub fn merged_messages(
        &self,
        messages: Vec<super::super::completions::message::Message>,
    ) -> Vec<super::super::completions::message::Message> {
        messages
    }

    /// Derives an identifier from the serialized configuration.
    ///
    /// The value is serialized to JSON, hashed with a 128-bit XxHash3 hasher
    /// seeded with `0`, base62-encoded, and left-padded with `'0'` to a
    /// minimum width of 22 characters.
    ///
    /// # Panics
    ///
    /// Panics if `serde_json` fails to serialize `self`.
    pub fn id(&self) -> String {
        let serialized = serde_json::to_string(self).unwrap();

        let mut hasher = XxHash3_128::with_seed(0);
        hasher.write(serialized.as_bytes());
        let digest = hasher.finish_128();

        format!("{:0>22}", base62::encode(digest))
    }

    /// Returns the model name for this agent kind, `"mock"`.
    pub const fn model() -> &'static str {
        "mock"
    }
}
// A mock agent: an `AgentBase` paired with an `id`.
//
// NOTE(review): plain `//` comments are used instead of `///` because the
// `JsonSchema` derive is understood to copy doc comments into the generated
// schema — confirm before converting.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "agent.mock.Agent")]
pub struct Agent {
// Identifier; the `TryFrom<AgentBase>` conversion fills it from `AgentBase::id`.
pub id: String,
// Flattened by serde, so the base fields are (de)serialized at the same level
// as `id` rather than under a nested `base` key.
#[serde(flatten)]
pub base: AgentBase,
}
impl TryFrom<AgentBase> for Agent {
type Error = String;
fn try_from(mut base: AgentBase) -> Result<Self, Self::Error> {
base.prepare();
base.validate()?;
let id = base.id();
Ok(Agent { id, base })
}
}