Skip to main content

objectiveai_sdk/agent/mock/
agent.rs

1//! Mock Agent types and validation logic.
2
3use serde::{Deserialize, Serialize};
4use twox_hash::XxHash3_128;
5use schemars::JsonSchema;
6
7/// The base configuration for a Mock Agent (without computed ID).
8#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
9#[schemars(rename = "agent.mock.AgentBase")]
10pub struct AgentBase {
11    /// The upstream provider marker.
12    pub upstream: super::Upstream,
13
14    /// The output mode for vector completions. Ignored for agent completions.
15    pub output_mode: super::OutputMode,
16
17    /// Number of top log probabilities to return (2-20).
18    ///
19    /// **Vector completions only.** Ignored for agent completions.
20    #[serde(skip_serializing_if = "Option::is_none")]
21    #[schemars(extend("omitempty" = true))]
22    #[arbitrary(with = crate::arbitrary_util::arbitrary_option_u64)]
23    pub top_logprobs: Option<u64>,
24
25    /// If true, the mock client will return an error instead of a response.
26    #[serde(default, skip_serializing_if = "Option::is_none")]
27    #[schemars(extend("omitempty" = true))]
28    pub error: Option<bool>,
29
30    /// Mock agent mode. Defaults to `default`.
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    #[schemars(extend("omitempty" = true))]
33    pub mode: Option<super::Mode>,
34
35    /// Probability (0-100) that the mock returns an error mid-stream.
36    /// Requires `error` to be `Some(true)`.
37    #[serde(default, skip_serializing_if = "Option::is_none")]
38    #[schemars(extend("omitempty" = true))]
39    pub error_probability: Option<u8>,
40
41    /// MCP servers the agent can connect to.
42    #[serde(skip_serializing_if = "Option::is_none")]
43    #[schemars(extend("omitempty" = true))]
44    pub mcp_servers: Option<super::super::McpServers>,
45}
46
47impl AgentBase {
48    /// Normalizes the configuration for deterministic ID computation.
49    pub fn prepare(&mut self) {
50        self.top_logprobs = match self.top_logprobs {
51            Some(0) | Some(1) => None,
52            other => other,
53        };
54        if self.error == Some(true) && self.error_probability == Some(0) {
55            self.error = None;
56            self.error_probability = None;
57        }
58        if self.error == Some(false) {
59            self.error = None;
60        }
61        if self.mode == Some(super::Mode::Default) {
62            self.mode = None;
63        }
64        self.mcp_servers = match self.mcp_servers.take() {
65            Some(mcp_servers) => super::super::mcp::mcp_servers::prepare(mcp_servers),
66            None => None,
67        };
68    }
69
70    /// Validates the configuration.
71    pub fn validate(&self) -> Result<(), String> {
72        if let Some(top_logprobs) = self.top_logprobs
73            && top_logprobs > 20
74        {
75            return Err("`top_logprobs` must be at most 20".to_string());
76        }
77        if self.mode == Some(super::Mode::Invention)
78            && self.output_mode != super::OutputMode::Instruction
79        {
80            return Err(
81                "`mode: invention` is only compatible with `instruction` output mode"
82                    .to_string(),
83            );
84        }
85        if let Some(mcp_servers) = &self.mcp_servers {
86            super::super::mcp::mcp_servers::validate(mcp_servers)?;
87        }
88        if let Some(p) = self.error_probability {
89            if p > 100 {
90                return Err("`error_probability` must be at most 100".to_string());
91            }
92            if self.error != Some(true) {
93                return Err("`error_probability` requires `error` to be true".to_string());
94            }
95        }
96        Ok(())
97    }
98
99    /// Returns the messages as-is.
100    pub fn merged_messages(
101        &self,
102        messages: Vec<super::super::completions::message::Message>,
103    ) -> Vec<super::super::completions::message::Message> {
104        messages
105    }
106
107    /// Computes the deterministic content-addressed ID.
108    pub fn id(&self) -> String {
109        let mut hasher = XxHash3_128::with_seed(0);
110        hasher.write(serde_json::to_string(self).unwrap().as_bytes());
111        format!("{:0>22}", base62::encode(hasher.finish_128()))
112    }
113
114    pub const fn model() -> &'static str {
115        "mock"
116    }
117}
118
119/// A validated Mock Agent with its computed content-addressed ID.
120#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
121#[schemars(rename = "agent.mock.Agent")]
122pub struct Agent {
123    /// The deterministic content-addressed ID (22-character base62 string).
124    pub id: String,
125    /// The normalized configuration.
126    #[serde(flatten)]
127    pub base: AgentBase,
128}
129
130impl TryFrom<AgentBase> for Agent {
131    type Error = String;
132    fn try_from(mut base: AgentBase) -> Result<Self, Self::Error> {
133        base.prepare();
134        base.validate()?;
135        let id = base.id();
136        Ok(Agent { id, base })
137    }
138}