use aws_sdk_bedrockruntime::types::{
ContentBlock, SpecificToolChoice, Tool, ToolChoice, ToolConfiguration, ToolInputSchema,
ToolSpecification,
};
use aws_smithy_types::Document;
use crate::llm::error::LlmError;
/// Builds a Bedrock [`ToolConfiguration`] that forces the model to answer by
/// calling a single tool whose input schema is `json_schema`.
///
/// The tool is named `schema_name`, and the tool choice is set to that
/// specific tool, so the model's structured answer arrives as the tool-use
/// input.
///
/// # Errors
///
/// Returns [`LlmError::Validation`] when `json_schema` is not a JSON object,
/// or when any of the SDK builders (`ToolSpecification`,
/// `SpecificToolChoice`, `ToolConfiguration`) rejects its input.
pub fn build_tool_config(
    schema_name: &str,
    json_schema: &serde_json::Value,
) -> Result<ToolConfiguration, LlmError> {
    // Bedrock only accepts an object at the top level of a tool input schema.
    let schema_doc = json_to_document(json_schema).ok_or_else(|| {
        LlmError::Validation(format!(
            "response_schema for {schema_name:?} must be a JSON object"
        ))
    })?;

    let description = format!(
        "Structured output tool for {schema_name}. \
         You MUST call this tool to return your response."
    );

    let spec = ToolSpecification::builder()
        .name(schema_name)
        .description(description)
        .input_schema(ToolInputSchema::Json(schema_doc))
        .build()
        .map_err(|e| LlmError::Validation(format!("build ToolSpecification: {e}")))?;

    // Pin the choice to this exact tool so the model cannot answer in free text.
    let forced_choice = SpecificToolChoice::builder()
        .name(schema_name)
        .build()
        .map_err(|e| LlmError::Validation(format!("build SpecificToolChoice: {e}")))?;

    ToolConfiguration::builder()
        .tools(Tool::ToolSpec(spec))
        .tool_choice(ToolChoice::Tool(forced_choice))
        .build()
        .map_err(|e| LlmError::Validation(format!("build ToolConfiguration: {e}")))
}
/// Returns the input of the first `ToolUse` content block in a Converse
/// response, serialised as a JSON string.
///
/// The tool name of the block is not checked; the first `ToolUse` block wins.
///
/// Returns `None` in any of these cases:
/// - the response has no output;
/// - the output is not a message;
/// - the message contains no `ToolUse` block;
/// - the tool input cannot be serialised.
pub fn extract_tool_use_json(
    resp: &aws_sdk_bedrockruntime::operation::converse::ConverseOutput,
) -> Option<String> {
    let message = resp.output()?.as_message().ok()?;
    let tool_input = message.content().iter().find_map(|block| match block {
        ContentBlock::ToolUse(tool_use) => Some(tool_use.input()),
        _ => None,
    })?;
    document_to_json_string(tool_input)
}
/// Converts a top-level JSON object into a smithy [`Document::Object`].
///
/// Returns `None` for any non-object value (arrays, strings, numbers,
/// booleans, null), because a tool input schema must be an object.
pub fn json_to_document(value: &serde_json::Value) -> Option<Document> {
    // The recursive converter already handles objects (and everything nested
    // inside them); this function only adds the top-level "must be an object"
    // gate.
    value.is_object().then(|| json_value_to_doc(value))
}
/// Recursively converts any `serde_json::Value` into a smithy [`Document`].
///
/// Numbers are classified so that they respect the variants of
/// `aws_smithy_types::Number`:
/// - non-negative integers become `PosInt`;
/// - negative integers become `NegInt`, whose payload is documented as always
///   negative;
/// - all other numbers become `Float`.
fn json_value_to_doc(v: &serde_json::Value) -> Document {
    match v {
        serde_json::Value::Null => Document::Null,
        serde_json::Value::Bool(b) => Document::Bool(*b),
        serde_json::Value::Number(n) => {
            // `as_u64` must be tried first: `as_i64` also succeeds for small
            // non-negative integers, which would otherwise yield e.g.
            // `NegInt(5)` and break the `NegInt` invariant.
            let number = if let Some(u) = n.as_u64() {
                aws_smithy_types::Number::PosInt(u)
            } else if let Some(i) = n.as_i64() {
                // Only negative integers reach this branch.
                aws_smithy_types::Number::NegInt(i)
            } else {
                // Non-integer numbers. `as_f64` is presumably always `Some` here
                // unless serde_json's `arbitrary_precision` feature is enabled
                // (TODO confirm); the 0.0 fallback is retained from the original.
                aws_smithy_types::Number::Float(n.as_f64().unwrap_or(0.0))
            };
            Document::Number(number)
        }
        serde_json::Value::String(s) => Document::String(s.clone()),
        serde_json::Value::Array(arr) => {
            Document::Array(arr.iter().map(json_value_to_doc).collect())
        }
        serde_json::Value::Object(map) => {
            let doc_map: std::collections::HashMap<String, Document> = map
                .iter()
                .map(|(k, v)| (k.clone(), json_value_to_doc(v)))
                .collect();
            Document::Object(doc_map)
        }
    }
}
/// Serialises a smithy [`Document`] to a compact JSON string.
///
/// Returns `None` if `serde_json` fails to serialise the converted value.
pub fn document_to_json_string(doc: &Document) -> Option<String> {
    serde_json::to_string(&doc_to_json_value(doc)).ok()
}
fn doc_to_json_value(doc: &Document) -> serde_json::Value {
match doc {
Document::Null => serde_json::Value::Null,
Document::Bool(b) => serde_json::Value::Bool(*b),
Document::Number(n) => match n {
aws_smithy_types::Number::PosInt(u) => {
serde_json::Value::Number(serde_json::Number::from(*u))
}
aws_smithy_types::Number::NegInt(i) => {
serde_json::Value::Number(serde_json::Number::from(*i))
}
aws_smithy_types::Number::Float(f) => serde_json::Number::from_f64(*f)
.map(serde_json::Value::Number)
.unwrap_or(serde_json::Value::Null),
},
Document::String(s) => serde_json::Value::String(s.clone()),
Document::Array(arr) => {
serde_json::Value::Array(arr.iter().map(doc_to_json_value).collect())
}
Document::Object(map) => {
let obj: serde_json::Map<String, serde_json::Value> = map
.iter()
.map(|(k, v)| (k.clone(), doc_to_json_value(v)))
.collect();
serde_json::Value::Object(obj)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal object schema with one required string property, shared by
    /// several tests below.
    fn verdict_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "verdict": { "type": "string" }
            },
            "required": ["verdict"]
        })
    }

    #[test]
    fn json_to_document_converts_object() {
        let doc = json_to_document(&verdict_schema()).expect("must convert object schema");
        assert!(
            matches!(doc, Document::Object(_)),
            "top-level must be Document::Object"
        );
    }

    #[test]
    fn json_to_document_rejects_non_object() {
        // Each case pairs a non-object input with the failure message to report.
        let cases = [
            (serde_json::json!([1, 2, 3]), "array must not convert to Document"),
            (serde_json::json!("hello"), "string must not convert to Document"),
        ];
        for (input, why) in &cases {
            assert!(json_to_document(input).is_none(), "{}", why);
        }
    }

    #[test]
    fn document_roundtrip_preserves_values() {
        let doc = Document::Object(std::collections::HashMap::from([
            (
                "verdict".to_string(),
                Document::String("APPROVE".to_string()),
            ),
            (
                "confidence".to_string(),
                Document::Number(aws_smithy_types::Number::Float(0.95)),
            ),
        ]));

        let json_str = document_to_json_string(&doc).expect("must serialise");
        let parsed: serde_json::Value = serde_json::from_str(&json_str).expect("must parse");

        assert_eq!(parsed["verdict"], "APPROVE");
        let conf = parsed["confidence"]
            .as_f64()
            .expect("confidence must be float");
        assert!((conf - 0.95).abs() < 1e-9);
    }

    #[test]
    fn build_tool_config_succeeds_for_valid_schema() {
        let result = build_tool_config("review_output", &verdict_schema());
        assert!(
            result.is_ok(),
            "build_tool_config must succeed for a valid object schema"
        );
    }

    #[test]
    fn build_tool_config_rejects_non_object_schema() {
        let schema = serde_json::json!("not-an-object");
        let err =
            build_tool_config("review_output", &schema).expect_err("non-object schema must fail");
        assert!(
            matches!(err, LlmError::Validation(_)),
            "expected Validation error, got: {err:?}"
        );
    }
}