use std::io::{BufRead, Write};
use std::sync::atomic::{AtomicI64, Ordering};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::debug;
use crate::mcp::protocol::RequestId;
/// Process-wide counter backing [`next_sampling_id`]; the first id handed out is 9001.
// NOTE(review): 9001 presumably keeps these ids apart from ids used elsewhere — confirm.
static SAMPLING_ID: AtomicI64 = AtomicI64::new(9001);

/// Returns a fresh numeric request id for an outgoing sampling request.
///
/// Every call atomically post-increments [`SAMPLING_ID`] and wraps the value
/// that was current *before* the increment in [`RequestId::Number`].
fn next_sampling_id() -> RequestId {
    let issued = SAMPLING_ID.fetch_add(1, Ordering::Relaxed);
    RequestId::Number(issued)
}
/// Serialisable JSON-RPC request envelope for a sampling call.
///
/// The struct is serialise-only; `create_message` in this file builds one,
/// turns it into a single JSON line and writes it to the output stream.
#[derive(Debug, Serialize)]
pub struct SamplingRequest {
/// JSON-RPC version string (`create_message` sets this to `"2.0"`).
pub jsonrpc: &'static str,
/// Identifier of this request.
pub id: RequestId,
/// Method name (`create_message` sets this to `"sampling/createMessage"`).
pub method: &'static str,
/// Payload of the request, serialised under the `params` key.
pub params: SamplingParams,
}
/// Payload serialised under the `params` key of a [`SamplingRequest`].
///
/// All field names are emitted in camelCase (`messages`, `maxTokens`,
/// `systemPrompt`, `modelPreferences`).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingParams {
    /// Conversation turns sent with the request.
    pub messages: Vec<SamplingMessage>,
    /// Token limit, emitted as `maxTokens`.
    pub max_tokens: u32,
    /// Optional system prompt, emitted as `systemPrompt`; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Optional model preferences, emitted as `modelPreferences`; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_preferences: Option<ModelPreferences>,
}
/// One conversation turn inside [`SamplingParams::messages`]: who is speaking
/// and what they contribute.
#[derive(Debug, Clone, Serialize)]
pub struct SamplingMessage {
/// Speaker of this turn.
pub role: SamplingRole,
/// Content of this turn (text or image).
pub content: SamplingContent,
}
/// Speaker of a sampling message.
///
/// Variants are (de)serialised in lowercase, i.e. `"user"` and `"assistant"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SamplingRole {
/// Serialised as `"user"`.
User,
/// Serialised as `"assistant"`.
Assistant,
}
/// Content of an outgoing [`SamplingMessage`].
///
/// Serialised as an internally tagged object: the variant name, lowercased,
/// is stored under the `type` key next to the variant's own fields.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SamplingContent {
    /// Plain text; serialised as `{"type": "text", "text": ...}`.
    Text {
        /// The text itself.
        text: String,
    },
    /// Image payload; serialised as `{"type": "image", "data": ..., "mimeType": ...}`.
    #[serde(rename_all = "camelCase")]
    Image {
        /// Image data as a string ([`SamplingContent::png`] stores base64 here).
        data: String,
        /// MIME type of `data`, emitted as `mimeType`.
        mime_type: &'static str,
    },
}
impl SamplingContent {
    /// Builds a [`SamplingContent::Text`] from anything convertible into a `String`.
    #[must_use]
    pub fn text(text: impl Into<String>) -> Self {
        let text = text.into();
        Self::Text { text }
    }

    /// Builds a [`SamplingContent::Image`] from raw bytes.
    ///
    /// The bytes are encoded with the standard base64 engine and tagged with
    /// the MIME type `image/png`; the bytes themselves are not inspected.
    #[must_use]
    pub fn png(bytes: &[u8]) -> Self {
        let encoder = base64::engine::general_purpose::STANDARD;
        let data = base64::Engine::encode(&encoder, bytes);
        Self::Image {
            data,
            mime_type: "image/png",
        }
    }
}
/// Optional model-selection hints attached to [`SamplingParams`].
///
/// Field names are emitted in camelCase and every unset field is left out of
/// the serialised JSON.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelPreferences {
    /// Emitted as `intelligencePriority` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub intelligence_priority: Option<f32>,
    /// Emitted as `speedPriority` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speed_priority: Option<f32>,
}
/// Deserialised `result` member of a sampling response.
///
/// Keys are read in camelCase (`role`, `content`, `model`, `stopReason`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingResult {
    /// Speaker of the returned message.
    pub role: SamplingRole,
    /// Returned content (text or image).
    pub content: SamplingResultContent,
    /// Value of the `model` key; `None` when the key is absent.
    #[serde(default)]
    pub model: Option<String>,
    /// Value of the `stopReason` key; `None` when the key is absent.
    #[serde(default)]
    pub stop_reason: Option<String>,
}
/// Content of a [`SamplingResult`].
///
/// Deserialised from an internally tagged object whose `type` key holds the
/// lowercased variant name (`"text"` or `"image"`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SamplingResultContent {
/// Text content, read from the `text` key.
Text { text: String },
/// Image content; only the `data` key is captured.
Image { data: String },
}
impl SamplingResultContent {
    /// Borrows the text of a [`SamplingResultContent::Text`].
    ///
    /// Returns `None` for [`SamplingResultContent::Image`].
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        if let Self::Text { text } = self {
            Some(text.as_str())
        } else {
            None
        }
    }
}
/// Errors produced by the sampling request/response exchange in this module.
#[derive(Debug, thiserror::Error)]
pub enum SamplingError {
/// The client does not support sampling.
// NOTE(review): not constructed in the visible part of this file; presumably raised by callers.
#[error("client does not support sampling")]
NotSupported,
/// Writing the request or reading the response failed. Converts from
/// `std::io::Error` via `#[from]`, so `?` can be used on I/O results.
#[error("I/O error during sampling exchange: {0}")]
Io(#[from] std::io::Error),
/// The request could not be serialised to JSON.
#[error("failed to serialise sampling request: {0}")]
Serialise(serde_json::Error),
/// The response line was not valid JSON, or its `result` member did not
/// have the shape of a `SamplingResult`.
#[error("failed to parse sampling response: {0}")]
Parse(serde_json::Error),
/// The response carried a JSON-RPC `error` member.
#[error("sampling response contained an error: code={code}, message={message}")]
RpcError { code: i32, message: String },
/// The response parsed successfully but its content was not text.
#[error("sampling response had no text content")]
NoTextContent,
}
pub fn create_message<W, R>(
out: &mut W,
input: &mut R,
messages: Vec<SamplingMessage>,
max_tokens: u32,
system_prompt: Option<String>,
) -> Result<String, SamplingError>
where
W: Write,
R: BufRead,
{
let id = next_sampling_id();
let request = SamplingRequest {
jsonrpc: "2.0",
id,
method: "sampling/createMessage",
params: SamplingParams {
messages,
max_tokens,
system_prompt,
model_preferences: None,
},
};
let json = serde_json::to_string(&request).map_err(SamplingError::Serialise)?;
debug!(bytes = json.len(), "sending sampling/createMessage");
writeln!(out, "{json}")?;
out.flush()?;
let mut line = String::new();
input.read_line(&mut line)?;
debug!(bytes = line.len(), "received sampling response");
parse_sampling_response(&line)
}
fn parse_sampling_response(line: &str) -> Result<String, SamplingError> {
let value: Value = serde_json::from_str(line.trim()).map_err(SamplingError::Parse)?;
if let Some(err) = value.get("error") {
let code = err["code"].as_i64().unwrap_or(-1) as i32;
let message = err["message"].as_str().unwrap_or("unknown").to_string();
return Err(SamplingError::RpcError { code, message });
}
let result: SamplingResult =
serde_json::from_value(value.get("result").cloned().unwrap_or(Value::Null))
.map_err(SamplingError::Parse)?;
result
.content
.as_text()
.map(str::to_string)
.ok_or(SamplingError::NoTextContent)
}
/// Small copyable flag recording whether sampling is available.
#[derive(Debug, Clone, Copy)]
pub struct SamplingContext {
    available: bool,
}

impl SamplingContext {
    /// Single construction point shared by the public constructors.
    const fn with_flag(available: bool) -> Self {
        Self { available }
    }

    /// Context whose [`is_available`](Self::is_available) returns `true`.
    #[must_use]
    pub const fn available() -> Self {
        Self::with_flag(true)
    }

    /// Context whose [`is_available`](Self::is_available) returns `false`.
    #[must_use]
    pub const fn unavailable() -> Self {
        Self::with_flag(false)
    }

    /// Returns the stored availability flag.
    #[must_use]
    pub const fn is_available(self) -> bool {
        let Self { available } = self;
        available
    }
}

/// `true` converts to an available context, `false` to an unavailable one.
impl From<bool> for SamplingContext {
    fn from(available: bool) -> Self {
        Self::with_flag(available)
    }
}
/// Builds the conversation used to ask a model to locate a UI element in a screenshot.
///
/// Returns two user turns — first the screenshot as a PNG image content
/// block, then a text turn naming `description` — together with a system
/// prompt that asks for a JSON answer containing `found`, and `x`/`y` center
/// coordinates when the element was found. The system prompt is always `Some`.
#[must_use]
pub fn locate_element_messages(
    description: &str,
    screenshot_png: &[u8],
) -> (Vec<SamplingMessage>, Option<String>) {
    const SYSTEM_PROMPT: &str = "You are a macOS UI automation assistant. When shown a screenshot, \
        identify the requested UI element and respond with a JSON object containing \
        the element's approximate center coordinates: \
        {\"found\": true, \"x\": <int>, \"y\": <int>, \"description\": \"<what you see>\"}. \
        If the element is not found, respond with: \
        {\"found\": false, \"description\": \"<what you see instead>\"}.";

    let screenshot_turn = SamplingMessage {
        role: SamplingRole::User,
        content: SamplingContent::png(screenshot_png),
    };

    let question = format!("Find this UI element in the screenshot: {description}");
    let question_turn = SamplingMessage {
        role: SamplingRole::User,
        content: SamplingContent::text(question),
    };

    (
        vec![screenshot_turn, question_turn],
        Some(String::from(SYSTEM_PROMPT)),
    )
}
// Unit tests for the sampling module: wire-format checks for the serialisable
// types, response parsing, one end-to-end exchange over in-memory streams,
// and the small helper types.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
// Text content serialises as an internally tagged object with `type: "text"`.
#[test]
fn sampling_content_text_serialises_with_type_tag() {
let content = SamplingContent::text("hello");
let v: serde_json::Value = serde_json::to_value(&content).unwrap();
assert_eq!(v["type"], "text");
assert_eq!(v["text"], "hello");
}
// Image content carries `type: "image"`, a camelCase `mimeType` and string `data`.
#[test]
fn sampling_content_image_serialises_with_mime_type() {
let content = SamplingContent::png(&[0u8]);
let v: serde_json::Value = serde_json::to_value(&content).unwrap();
assert_eq!(v["type"], "image");
assert_eq!(v["mimeType"], "image/png");
assert!(v["data"].as_str().is_some());
}
// The envelope exposes jsonrpc/id/method/params, and optional params that
// are `None` are left out of the JSON entirely.
#[test]
fn sampling_request_serialises_required_fields() {
let req = SamplingRequest {
jsonrpc: "2.0",
id: RequestId::Number(9001),
method: "sampling/createMessage",
params: SamplingParams {
messages: vec![SamplingMessage {
role: SamplingRole::User,
content: SamplingContent::text("describe"),
}],
max_tokens: 256,
system_prompt: None,
model_preferences: None,
},
};
let v: serde_json::Value = serde_json::to_value(&req).unwrap();
assert_eq!(v["jsonrpc"], "2.0");
assert_eq!(v["id"], 9001);
assert_eq!(v["method"], "sampling/createMessage");
assert_eq!(v["params"]["maxTokens"], 256);
assert!(v["params"]["messages"].as_array().is_some());
assert!(v["params"].get("systemPrompt").is_none());
assert!(v["params"].get("modelPreferences").is_none());
}
// A system prompt that is set appears under the camelCase key `systemPrompt`.
#[test]
fn sampling_request_includes_system_prompt_when_set() {
let req = SamplingRequest {
jsonrpc: "2.0",
id: RequestId::Number(9002),
method: "sampling/createMessage",
params: SamplingParams {
messages: vec![],
max_tokens: 128,
system_prompt: Some("be concise".into()),
model_preferences: None,
},
};
let v: serde_json::Value = serde_json::to_value(&req).unwrap();
assert_eq!(v["params"]["systemPrompt"], "be concise");
}
// Happy path: a text result yields its text.
#[test]
fn parse_response_extracts_text_from_result() {
let line = r#"{"jsonrpc":"2.0","id":9001,"result":{"role":"assistant","content":{"type":"text","text":"I see a Save button."}}}"#;
let text = parse_sampling_response(line).unwrap();
assert_eq!(text, "I see a Save button.");
}
// An `error` member is mapped to `RpcError` with the code preserved.
#[test]
fn parse_response_returns_rpc_error_on_error_object() {
let line =
r#"{"jsonrpc":"2.0","id":9001,"error":{"code":-32600,"message":"Invalid Request"}}"#;
let err = parse_sampling_response(line).unwrap_err();
assert!(matches!(err, SamplingError::RpcError { code: -32600, .. }));
}
// Input that is not JSON at all is reported as a parse error.
#[test]
fn parse_response_returns_parse_error_on_bad_json() {
let err = parse_sampling_response("not json").unwrap_err();
assert!(matches!(err, SamplingError::Parse(_)));
}
// A well-formed result whose content is an image has no text to return.
#[test]
fn parse_response_returns_no_text_on_image_content() {
let line = r#"{"jsonrpc":"2.0","id":9001,"result":{"role":"assistant","content":{"type":"image","data":"abc123"}}}"#;
let err = parse_sampling_response(line).unwrap_err();
assert!(matches!(err, SamplingError::NoTextContent));
}
// End-to-end exchange over in-memory streams. The canned reply uses id 9999,
// which is not checked against the id of the request that was written.
#[test]
fn create_message_writes_request_and_reads_response() {
let response_line = "{\"jsonrpc\":\"2.0\",\"id\":9999,\"result\":{\"role\":\"assistant\",\"content\":{\"type\":\"text\",\"text\":\"found it\"}}}\n";
let mut input = Cursor::new(response_line.as_bytes().to_vec());
let mut output = Vec::<u8>::new();
let text = create_message(
&mut output,
&mut input,
vec![SamplingMessage {
role: SamplingRole::User,
content: SamplingContent::text("find the Save button"),
}],
256,
None,
)
.unwrap();
assert_eq!(text, "found it");
// The bytes written to `output` must parse back as the request that was sent.
let written = std::str::from_utf8(&output).unwrap();
let v: serde_json::Value = serde_json::from_str(written.trim()).unwrap();
assert_eq!(v["method"], "sampling/createMessage");
assert_eq!(v["params"]["maxTokens"], 256);
}
// The two named constructors report the matching availability.
#[test]
fn sampling_context_available_reports_true() {
assert!(SamplingContext::available().is_available());
}
#[test]
fn sampling_context_unavailable_reports_false() {
assert!(!SamplingContext::unavailable().is_available());
}
// The locate-element helper yields an image turn followed by a text turn,
// plus a system prompt.
#[test]
fn locate_element_messages_produces_two_user_turns() {
let (messages, system_prompt) = locate_element_messages("Save button", &[0u8; 4]);
assert_eq!(messages.len(), 2);
assert!(matches!(messages[0].role, SamplingRole::User));
assert!(matches!(messages[0].content, SamplingContent::Image { .. }));
assert!(matches!(messages[1].content, SamplingContent::Text { .. }));
assert!(system_prompt.is_some());
}
// Two consecutive ids are numeric and strictly increasing.
#[test]
fn next_sampling_id_is_monotonically_increasing() {
let a = match next_sampling_id() {
RequestId::Number(n) => n,
RequestId::String(_) => panic!("expected number"),
};
let b = match next_sampling_id() {
RequestId::Number(n) => n,
RequestId::String(_) => panic!("expected number"),
};
assert!(b > a);
}
}