use std::collections::BTreeMap;
use std::rc::Rc;
use serde_json::{json, Value as JsonValue};
use crate::schema::json_to_vm_value;
use crate::stdlib::host::{dispatch_host_call_bridge, dispatch_mock_host_call};
use crate::value::{VmError, VmValue};
/// JSON-RPC method name for inbound MCP sampling requests.
pub const SAMPLING_METHOD: &str = "sampling/createMessage";
/// Parsed, validated form of the params of an MCP `sampling/createMessage`
/// request. Wire keys are camelCase (`maxTokens`, `stopSequences`, ...);
/// fields here follow Rust naming.
#[derive(Debug, Clone)]
struct SamplingRequest {
    // Conversation messages, kept in their original JSON shape.
    messages: Vec<JsonValue>,
    // Optional `systemPrompt`; empty strings are treated as absent.
    system: Option<String>,
    // Required positive token budget (`maxTokens`).
    max_tokens: i64,
    temperature: Option<f64>,
    stop_sequences: Option<Vec<String>>,
    // Raw `modelPreferences` object; consulted for model-name hints.
    model_preferences: Option<JsonValue>,
    // Pass-through fields forwarded verbatim to the LLM call options.
    tools: Option<JsonValue>,
    // Accepts both `toolChoice` and `tool_choice` spellings on the wire.
    tool_choice: Option<JsonValue>,
    thinking: Option<JsonValue>,
    metadata: Option<JsonValue>,
    // `includeContext` hint (e.g. "thisServer" in tests) — not validated here.
    include_context: Option<String>,
}
/// Outcome of asking the host whether a sampling request may proceed.
#[derive(Debug, Clone)]
enum ApprovalDecision {
    // Proceed; the map holds option overrides merged into the LLM call.
    Accept(BTreeMap<String, VmValue>),
    // Refuse, carrying a human-readable reason for the error payload.
    Decline(String),
}
/// Handle an inbound MCP `sampling/createMessage` request end to end:
/// validate params, ask the host for approval, run the LLM call, and wrap
/// the outcome (or failure) in a JSON-RPC response.
///
/// Error mapping: invalid params -> -32602; host decline -> -32603 with a
/// `mcp.samplingDeclined` data payload; LLM failure -> -32000 with
/// `mcp.samplingFailed`.
pub async fn dispatch_inbound_sampling(server_name: &str, request: &JsonValue) -> JsonValue {
    let id = request.get("id").cloned().unwrap_or(JsonValue::Null);
    let params = request.get("params").cloned().unwrap_or_else(|| json!({}));
    // Fix: the argument was mojibake (`¶ms` from a corrupted `&params`).
    let parsed = match parse_sampling_request(&params) {
        Ok(p) => p,
        Err(detail) => return crate::jsonrpc::error_response(id, -32602, &detail),
    };
    // The host sees the raw params (not the parsed form) when deciding.
    let approval = ask_host_approval(server_name, &params).await;
    let overrides = match approval {
        ApprovalDecision::Accept(map) => map,
        ApprovalDecision::Decline(reason) => {
            return crate::jsonrpc::error_response_with_data(
                id,
                -32603,
                &format!("Sampling declined: {reason}"),
                json!({
                    "type": "mcp.samplingDeclined",
                    "method": SAMPLING_METHOD,
                    "reason": reason,
                }),
            );
        }
    };
    match run_llm_call(&parsed, overrides).await {
        Ok(outcome) => crate::jsonrpc::response(id, build_spec_response(outcome, &parsed)),
        Err(detail) => crate::jsonrpc::error_response_with_data(
            id,
            -32000,
            &format!("Sampling failed: {detail}"),
            json!({
                "type": "mcp.samplingFailed",
                "method": SAMPLING_METHOD,
                "reason": detail,
            }),
        ),
    }
}
/// Validate and decompose `sampling/createMessage` params into a
/// [`SamplingRequest`]. Any structural problem yields a human-readable
/// error string (surfaced to the client as Invalid params).
fn parse_sampling_request(params: &JsonValue) -> Result<SamplingRequest, String> {
    let object = match params.as_object() {
        Some(map) => map,
        None => return Err("sampling params must be a JSON object".to_string()),
    };

    // `messages` is required, must be an array, and must be non-empty.
    let messages: Vec<JsonValue> = match object.get("messages") {
        Some(JsonValue::Array(items)) => items.clone(),
        Some(_) => return Err("sampling params 'messages' must be an array".into()),
        None => return Err("sampling params 'messages' is required".into()),
    };
    if messages.is_empty() {
        return Err("sampling params 'messages' must not be empty".into());
    }

    // Every message needs a recognized role and some content.
    for (idx, message) in messages.iter().enumerate() {
        let role = message
            .get("role")
            .and_then(JsonValue::as_str)
            .ok_or_else(|| format!("sampling messages[{idx}].role is required"))?;
        match role {
            "user" | "assistant" | "system" => {}
            _ => {
                return Err(format!(
                    "sampling messages[{idx}].role must be 'user'/'assistant'/'system' (got {role:?})"
                ))
            }
        }
        if message.get("content").is_none() {
            return Err(format!("sampling messages[{idx}].content is required"));
        }
    }

    // `maxTokens` is a required, strictly positive integer.
    let max_tokens = match object.get("maxTokens").and_then(JsonValue::as_i64) {
        Some(n) if n > 0 => n,
        Some(n) => {
            return Err(format!(
                "sampling params 'maxTokens' must be positive (got {n})"
            ))
        }
        None => {
            return Err(
                "sampling params 'maxTokens' is required and must be an integer".to_string(),
            )
        }
    };

    // `temperature` is optional; JSON null counts as absent.
    let temperature = match object.get("temperature") {
        None | Some(JsonValue::Null) => None,
        Some(JsonValue::Number(num)) => match num.as_f64() {
            Some(t) => Some(t),
            None => {
                return Err("sampling params 'temperature' must be a finite number".to_string())
            }
        },
        Some(_) => return Err("sampling params 'temperature' must be a number".into()),
    };

    // `stopSequences` is an optional array of strings; null counts as absent.
    let stop_sequences = match object.get("stopSequences") {
        None | Some(JsonValue::Null) => None,
        Some(JsonValue::Array(items)) => {
            let mut collected = Vec::with_capacity(items.len());
            for (idx, item) in items.iter().enumerate() {
                match item.as_str() {
                    Some(s) => collected.push(s.to_string()),
                    None => {
                        return Err(format!(
                            "sampling params 'stopSequences[{idx}]' must be a string"
                        ))
                    }
                }
            }
            Some(collected)
        }
        Some(_) => return Err("sampling params 'stopSequences' must be an array".into()),
    };

    // Empty system prompts are normalized to None.
    let system = object
        .get("systemPrompt")
        .and_then(JsonValue::as_str)
        .filter(|s| !s.is_empty())
        .map(str::to_string);

    // Accept both camelCase and snake_case spellings for the tool choice.
    let tool_choice = object
        .get("toolChoice")
        .or_else(|| object.get("tool_choice"))
        .cloned();

    Ok(SamplingRequest {
        messages,
        system,
        max_tokens,
        temperature,
        stop_sequences,
        model_preferences: object.get("modelPreferences").cloned(),
        tools: object.get("tools").cloned(),
        tool_choice,
        thinking: object.get("thinking").cloned(),
        metadata: object.get("metadata").cloned(),
        include_context: object
            .get("includeContext")
            .and_then(JsonValue::as_str)
            .map(str::to_string),
    })
}
async fn ask_host_approval(server_name: &str, params: &JsonValue) -> ApprovalDecision {
let mut bridge_params: BTreeMap<String, VmValue> = BTreeMap::new();
bridge_params.insert("server".to_string(), VmValue::String(Rc::from(server_name)));
bridge_params.insert("params".to_string(), json_to_vm_value(params));
let result = dispatch_mock_host_call("mcp", "sample", &bridge_params)
.or_else(|| dispatch_host_call_bridge("mcp", "sample", &bridge_params));
let raw = match result {
Some(Ok(value)) => value,
Some(Err(error)) => {
return ApprovalDecision::Decline(host_error_to_string(error));
}
None => {
return ApprovalDecision::Decline(
"no host bridge installed for ('mcp', 'sample')".into(),
);
}
};
coerce_bridge_response(raw)
}
/// Interpret the host bridge's reply as an approval decision.
///
/// Accepted shapes: `nil`/`false` decline, `true` accepts with no overrides,
/// a dict with an `action` field ("accept" / "decline" / "cancel"), or a
/// bare dict, which is treated as the accepted override set itself.
fn coerce_bridge_response(value: VmValue) -> ApprovalDecision {
    // Handle all non-dict shapes up front; everything past this point is a dict.
    let dict = match value {
        VmValue::Nil => return ApprovalDecision::Decline("host bridge returned nil".into()),
        VmValue::Bool(false) => {
            return ApprovalDecision::Decline("host bridge declined".into())
        }
        VmValue::Bool(true) => return ApprovalDecision::Accept(BTreeMap::new()),
        VmValue::Dict(dict) => dict.as_ref().clone(),
        other => {
            return ApprovalDecision::Decline(format!(
                "host bridge returned unsupported value: {}",
                other.display()
            ))
        }
    };
    // A non-string `action` value is treated the same as no action at all.
    let action = dict.get("action").and_then(|v| {
        if let VmValue::String(s) = v {
            Some(s.to_string())
        } else {
            None
        }
    });
    match action.as_deref() {
        // No explicit action: the dict itself is the override set.
        None => ApprovalDecision::Accept(dict),
        Some("accept") => {
            let overrides = match dict.get("options") {
                Some(VmValue::Dict(d)) => d.as_ref().clone(),
                _ => BTreeMap::new(),
            };
            ApprovalDecision::Accept(overrides)
        }
        Some("decline") | Some("cancel") => {
            let reason = dict
                .get("message")
                .or_else(|| dict.get("reason"))
                .map(VmValue::display)
                .unwrap_or_else(|| "host bridge declined".to_string());
            ApprovalDecision::Decline(reason)
        }
        Some(other) => ApprovalDecision::Decline(format!(
            "host bridge returned unknown action {other:?}"
        )),
    }
}
/// Flatten a `VmError` into a plain message string for JSON-RPC payloads.
fn host_error_to_string(error: VmError) -> String {
    match error {
        // A thrown string already is the message; other thrown values are displayed.
        VmError::Thrown(value) => match value {
            VmValue::String(s) => s.to_string(),
            other => other.display(),
        },
        VmError::Runtime(message) | VmError::TypeError(message) => message,
        other => format!("{other:?}"),
    }
}
/// Minimal result of an LLM call: assistant text plus the model name the
/// backend reported (empty string when unknown).
#[derive(Debug, Clone)]
struct LlmOutcome {
    text: String,
    model: String,
}
/// Execute an approved sampling request through the crate's LLM layer.
///
/// `overrides` are the host-approved option overrides; they are merged into
/// the call args (last-write-wins) by `build_llm_call_args`. All errors are
/// flattened to strings via `host_error_to_string`.
async fn run_llm_call(
    parsed: &SamplingRequest,
    overrides: BTreeMap<String, VmValue>,
) -> Result<LlmOutcome, String> {
    let (vm_args, options_dict) = build_llm_call_args(parsed, overrides);
    let opts = crate::llm::extract_llm_options(&vm_args).map_err(host_error_to_string)?;
    let result = crate::llm::execute_llm_call(opts, Some(options_dict), None)
        .await
        .map_err(host_error_to_string)?;
    extract_assistant_outcome(&result)
}
/// Pull the assistant text (and, when present, the model name) out of an
/// LLM-call result. Dicts must carry a 'text' field; any non-dict value is
/// rendered as bare text with an empty model.
fn extract_assistant_outcome(result: &VmValue) -> Result<LlmOutcome, String> {
    // Dict results carry structured fields.
    if let VmValue::Dict(d) = result {
        let text = match d.get("text") {
            Some(VmValue::String(s)) => s.to_string(),
            Some(other) => other.display(),
            None => return Err("llm_call result missing 'text' field".into()),
        };
        let model = d.get("model").map(VmValue::display).unwrap_or_default();
        return Ok(LlmOutcome { text, model });
    }
    // Anything else becomes plain text with no model attribution.
    let text = match result {
        VmValue::String(s) => s.to_string(),
        other => other.display(),
    };
    Ok(LlmOutcome {
        text,
        model: String::new(),
    })
}
/// Assemble the argument vector and options dict for the crate's LLM call.
///
/// Returns `(args, options)`:
/// * `args` is `[prompt, system, options]` — the prompt slot is an empty
///   string because the conversation travels in `options["messages"]`.
/// * `options` is the same map, also handed to `execute_llm_call` directly.
///
/// Host-approved `overrides` are inserted last, so they take precedence
/// over every request-derived key.
fn build_llm_call_args(
    parsed: &SamplingRequest,
    overrides: BTreeMap<String, VmValue>,
) -> (Vec<VmValue>, BTreeMap<String, VmValue>) {
    let mut options: BTreeMap<String, VmValue> = BTreeMap::new();
    let messages_vm: Vec<VmValue> = parsed.messages.iter().map(json_to_vm_value).collect();
    options.insert("messages".to_string(), VmValue::List(Rc::new(messages_vm)));
    options.insert("max_tokens".to_string(), VmValue::Int(parsed.max_tokens));
    if let Some(temperature) = parsed.temperature {
        options.insert("temperature".to_string(), VmValue::Float(temperature));
    }
    if let Some(stop) = parsed.stop_sequences.as_ref() {
        let stop_vm: Vec<VmValue> = stop
            .iter()
            .map(|s| VmValue::String(Rc::from(s.as_str())))
            .collect();
        options.insert("stop".to_string(), VmValue::List(Rc::new(stop_vm)));
    }
    // First non-empty model hint (if any) becomes the requested model.
    if let Some(hint) = pick_model_hint(parsed.model_preferences.as_ref()) {
        options.insert(
            "model".to_string(),
            VmValue::String(Rc::from(hint.as_str())),
        );
    }
    // Optional pass-through fields, forwarded verbatim when present.
    if let Some(tools) = parsed.tools.as_ref() {
        options.insert("tools".to_string(), json_to_vm_value(tools));
    }
    if let Some(tool_choice) = parsed.tool_choice.as_ref() {
        options.insert("tool_choice".to_string(), json_to_vm_value(tool_choice));
    }
    if let Some(thinking) = parsed.thinking.as_ref() {
        options.insert("thinking".to_string(), json_to_vm_value(thinking));
    }
    if let Some(metadata) = parsed.metadata.as_ref() {
        options.insert("metadata".to_string(), json_to_vm_value(metadata));
    }
    if let Some(include_context) = parsed.include_context.as_ref() {
        options.insert(
            "include_context".to_string(),
            VmValue::String(Rc::from(include_context.as_str())),
        );
    }
    // Overrides last: host decisions win over request-derived settings.
    for (key, value) in overrides {
        options.insert(key, value);
    }
    let system_value = parsed
        .system
        .as_ref()
        .map(|s| VmValue::String(Rc::from(s.as_str())))
        .unwrap_or(VmValue::Nil);
    // Positional shape: [prompt (empty, unused), system, options-dict].
    let args = vec![
        VmValue::String(Rc::from("")),
        system_value,
        VmValue::Dict(Rc::new(options.clone())),
    ];
    (args, options)
}
fn pick_model_hint(prefs: Option<&JsonValue>) -> Option<String> {
let prefs = prefs?;
let hints = prefs.get("hints")?.as_array()?;
for hint in hints {
if let Some(name) = hint.get("name").and_then(|value| value.as_str()) {
if !name.is_empty() {
return Some(name.to_string());
}
}
}
None
}
/// Shape an [`LlmOutcome`] into the MCP `sampling/createMessage` result.
///
/// NOTE(review): `stopReason` reports "stopSequence" whenever the request
/// *supplied* stop sequences, regardless of how generation actually ended —
/// confirm this is the intended approximation.
fn build_spec_response(outcome: LlmOutcome, parsed: &SamplingRequest) -> JsonValue {
    let stop_reason = match parsed.stop_sequences {
        Some(_) => "stopSequence",
        None => "endTurn",
    };
    // Prefer the model the backend reported; fall back to the request hint.
    let mut model = outcome.model;
    if model.is_empty() {
        model = pick_model_hint(parsed.model_preferences.as_ref()).unwrap_or_default();
    }
    json!({
        "role": "assistant",
        "content": {
            "type": "text",
            "text": outcome.text,
        },
        "model": model,
        "stopReason": stop_reason,
    })
}
// Unit tests: request parsing, bridge-response coercion, spec-response
// shaping, and end-to-end dispatch through mock bridges.
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest params object that passes validation.
    fn minimal_request() -> JsonValue {
        json!({
            "messages": [
                {"role": "user", "content": {"type": "text", "text": "hi"}}
            ],
            "maxTokens": 64,
        })
    }

    // ---- parse_sampling_request validation ----

    #[test]
    fn parse_rejects_missing_messages() {
        let err = parse_sampling_request(&json!({"maxTokens": 1})).unwrap_err();
        assert!(err.contains("messages"));
    }

    #[test]
    fn parse_rejects_empty_messages() {
        let err = parse_sampling_request(&json!({"messages": [], "maxTokens": 1})).unwrap_err();
        assert!(err.contains("must not be empty"));
    }

    #[test]
    fn parse_rejects_missing_max_tokens() {
        let err = parse_sampling_request(&json!({
            "messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}]
        }))
        .unwrap_err();
        assert!(err.contains("maxTokens"));
    }

    #[test]
    fn parse_rejects_zero_max_tokens() {
        let err = parse_sampling_request(&json!({
            "messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}],
            "maxTokens": 0,
        }))
        .unwrap_err();
        assert!(err.contains("positive"));
    }

    #[test]
    fn parse_rejects_unknown_role() {
        let err = parse_sampling_request(&json!({
            "messages": [{"role": "tool", "content": {}}],
            "maxTokens": 1,
        }))
        .unwrap_err();
        assert!(err.contains("'user'/'assistant'/'system'"));
    }

    #[test]
    fn parse_extracts_optional_fields() {
        let parsed = parse_sampling_request(&json!({
            "messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}],
            "maxTokens": 32,
            "systemPrompt": "be brief",
            "temperature": 0.2,
            "stopSequences": ["END"],
            "modelPreferences": {"hints": [{"name": "claude-3-5-sonnet"}]},
            "includeContext": "thisServer",
            "metadata": {"trace": "abc"},
        }))
        .unwrap();
        assert_eq!(parsed.max_tokens, 32);
        assert_eq!(parsed.system.as_deref(), Some("be brief"));
        assert_eq!(parsed.temperature, Some(0.2));
        assert_eq!(
            parsed.stop_sequences.as_deref(),
            Some(&["END".to_string()][..])
        );
        assert_eq!(parsed.include_context.as_deref(), Some("thisServer"));
        assert_eq!(
            pick_model_hint(parsed.model_preferences.as_ref()),
            Some("claude-3-5-sonnet".to_string())
        );
    }

    // ---- pick_model_hint ----

    #[test]
    fn pick_model_hint_picks_first_non_empty() {
        let prefs = json!({"hints": [{"name": ""}, {"name": "gpt-4"}]});
        assert_eq!(pick_model_hint(Some(&prefs)), Some("gpt-4".to_string()));
    }

    #[test]
    fn pick_model_hint_returns_none_for_empty_chain() {
        assert!(pick_model_hint(None).is_none());
        assert!(pick_model_hint(Some(&json!({"hints": []}))).is_none());
        assert!(pick_model_hint(Some(&json!({}))).is_none());
    }

    // ---- coerce_bridge_response ----

    #[test]
    fn coerce_bridge_response_nil_declines() {
        match coerce_bridge_response(VmValue::Nil) {
            ApprovalDecision::Decline(_) => {}
            other => panic!("expected decline, got {other:?}"),
        }
    }

    #[test]
    fn coerce_bridge_response_true_accepts_with_no_overrides() {
        match coerce_bridge_response(VmValue::Bool(true)) {
            ApprovalDecision::Accept(map) => assert!(map.is_empty()),
            other => panic!("expected accept, got {other:?}"),
        }
    }

    #[test]
    fn coerce_bridge_response_accept_with_options() {
        let mut dict = BTreeMap::new();
        dict.insert("action".to_string(), VmValue::String(Rc::from("accept")));
        let mut options = BTreeMap::new();
        options.insert("provider".to_string(), VmValue::String(Rc::from("mock")));
        dict.insert("options".to_string(), VmValue::Dict(Rc::new(options)));
        match coerce_bridge_response(VmValue::Dict(Rc::new(dict))) {
            ApprovalDecision::Accept(map) => {
                assert_eq!(
                    map.get("provider").map(|v| v.display()).as_deref(),
                    Some("mock")
                );
            }
            other => panic!("expected accept, got {other:?}"),
        }
    }

    #[test]
    fn coerce_bridge_response_decline_with_message() {
        let mut dict = BTreeMap::new();
        dict.insert("action".to_string(), VmValue::String(Rc::from("decline")));
        dict.insert(
            "message".to_string(),
            VmValue::String(Rc::from("rate limit")),
        );
        match coerce_bridge_response(VmValue::Dict(Rc::new(dict))) {
            ApprovalDecision::Decline(reason) => assert_eq!(reason, "rate limit"),
            other => panic!("expected decline, got {other:?}"),
        }
    }

    #[test]
    fn coerce_bridge_response_bare_dict_is_overrides() {
        let mut dict = BTreeMap::new();
        dict.insert("provider".to_string(), VmValue::String(Rc::from("mock")));
        match coerce_bridge_response(VmValue::Dict(Rc::new(dict))) {
            ApprovalDecision::Accept(map) => {
                assert_eq!(
                    map.get("provider").map(|v| v.display()).as_deref(),
                    Some("mock")
                );
            }
            other => panic!("expected accept, got {other:?}"),
        }
    }

    // ---- build_spec_response ----

    /// Shorthand for constructing an `LlmOutcome` in tests.
    fn outcome(text: &str, model: &str) -> LlmOutcome {
        LlmOutcome {
            text: text.to_string(),
            model: model.to_string(),
        }
    }

    #[test]
    fn build_spec_response_flags_stop_sequence() {
        let parsed = parse_sampling_request(&json!({
            "messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}],
            "maxTokens": 4,
            "stopSequences": ["END"],
        }))
        .unwrap();
        let response = build_spec_response(outcome("done", "actual-model"), &parsed);
        assert_eq!(response["stopReason"], json!("stopSequence"));
        assert_eq!(response["role"], json!("assistant"));
        assert_eq!(response["content"]["type"], json!("text"));
        assert_eq!(response["content"]["text"], json!("done"));
        assert_eq!(response["model"], json!("actual-model"));
    }

    #[test]
    fn build_spec_response_default_stop_reason_is_end_turn() {
        let parsed = parse_sampling_request(&minimal_request()).unwrap();
        let response = build_spec_response(outcome("done", ""), &parsed);
        assert_eq!(response["stopReason"], json!("endTurn"));
    }

    #[test]
    fn build_spec_response_falls_back_to_hint_when_outcome_model_missing() {
        let parsed = parse_sampling_request(&json!({
            "messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}],
            "maxTokens": 4,
            "modelPreferences": {"hints": [{"name": "claude-3-5-sonnet"}]},
        }))
        .unwrap();
        let response = build_spec_response(outcome("done", ""), &parsed);
        assert_eq!(response["model"], json!("claude-3-5-sonnet"));
    }

    // ---- end-to-end dispatch ----

    #[tokio::test(flavor = "current_thread")]
    async fn dispatch_with_no_bridge_declines() {
        let request = json!({
            "jsonrpc": "2.0",
            "id": "s-1",
            "method": SAMPLING_METHOD,
            "params": minimal_request(),
        });
        let response = dispatch_inbound_sampling("mock", &request).await;
        assert_eq!(response["id"], json!("s-1"));
        assert_eq!(response["error"]["code"], json!(-32603));
        assert_eq!(
            response["error"]["data"]["type"],
            json!("mcp.samplingDeclined")
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn dispatch_with_invalid_params_returns_invalid_params() {
        let request = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": SAMPLING_METHOD,
            "params": {"messages": []},
        });
        let response = dispatch_inbound_sampling("mock", &request).await;
        assert_eq!(response["id"], json!(1));
        assert_eq!(response["error"]["code"], json!(-32602));
    }

    /// Test bridge that approves every ('mcp', 'sample') call with a fixed
    /// set of option overrides.
    struct ApproveSamplingBridge {
        overrides: BTreeMap<String, VmValue>,
    }

    impl crate::stdlib::host::HostCallBridge for ApproveSamplingBridge {
        fn dispatch(
            &self,
            capability: &str,
            operation: &str,
            _params: &BTreeMap<String, VmValue>,
        ) -> Result<Option<VmValue>, VmError> {
            if capability == "mcp" && operation == "sample" {
                let mut envelope: BTreeMap<String, VmValue> = BTreeMap::new();
                envelope.insert("action".to_string(), VmValue::String(Rc::from("accept")));
                envelope.insert(
                    "options".to_string(),
                    VmValue::Dict(Rc::new(self.overrides.clone())),
                );
                Ok(Some(VmValue::Dict(Rc::new(envelope))))
            } else {
                Ok(None)
            }
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn dispatch_with_mock_bridge_routes_to_llm_call() {
        // Queue one canned LLM reply, approve via the bridge, then dispatch.
        crate::llm::mock::reset_llm_mock_state();
        crate::llm::mock::push_llm_mock(crate::llm::mock::LlmMock {
            text: "sampled text".to_string(),
            tool_calls: Vec::new(),
            match_pattern: None,
            consume_on_match: true,
            input_tokens: None,
            output_tokens: None,
            cache_read_tokens: None,
            cache_write_tokens: None,
            thinking: None,
            thinking_summary: None,
            stop_reason: None,
            model: "mock-model".to_string(),
            provider: Some("mock".to_string()),
            blocks: None,
            error: None,
        });
        let mut overrides: BTreeMap<String, VmValue> = BTreeMap::new();
        overrides.insert("provider".to_string(), VmValue::String(Rc::from("mock")));
        overrides.insert("model".to_string(), VmValue::String(Rc::from("mock-model")));
        crate::stdlib::host::set_host_call_bridge(Rc::new(ApproveSamplingBridge { overrides }));
        let request = json!({
            "jsonrpc": "2.0",
            "id": 7,
            "method": SAMPLING_METHOD,
            "params": {
                "messages": [
                    {"role": "user", "content": {"type": "text", "text": "ping"}}
                ],
                "maxTokens": 32,
                "modelPreferences": {"hints": [{"name": "mock-model"}]},
            },
        });
        let response = dispatch_inbound_sampling("test-server", &request).await;
        // Clean up global mock/bridge state before asserting.
        crate::llm::mock::reset_llm_mock_state();
        crate::stdlib::host::clear_host_call_bridge();
        assert_eq!(response["id"], json!(7));
        assert!(
            response.get("result").is_some(),
            "expected success result, got {response:?}"
        );
        assert_eq!(response["result"]["role"], json!("assistant"));
        assert_eq!(response["result"]["content"]["type"], json!("text"));
        assert_eq!(response["result"]["content"]["text"], json!("sampled text"));
        assert_eq!(response["result"]["stopReason"], json!("endTurn"));
        assert_eq!(response["result"]["model"], json!("mock-model"));
    }
}