use inferd_proto::v2::{Attachment, ContentBlock, MessageV2, ResolvedV2, RoleV2, Tool, ToolCallId};
use serde_json::Value;
/// Placeholder written into the prompt once for every image, audio, or video
/// content block. The referenced attachment is pushed onto
/// `Gemma4Rendered::attachments` at the moment the marker is written, so
/// markers and attachments are emitted in the same order.
///
/// NOTE(review): text blocks are copied into the prompt verbatim, so a text
/// block that itself contains this string would add a marker with no
/// attachment behind it — confirm whether the downstream consumer guards
/// against that.
pub const MEDIA_MARKER: &str = "<__media__>";
/// Output of `Gemma4Renderer::render`: the prompt text plus the media
/// attachments it refers to.
///
/// `'a` is the lifetime of the `ResolvedV2` that was rendered; attachments
/// are borrowed from it rather than copied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Gemma4Rendered<'a> {
/// Full prompt text. Starts with `<bos>` and ends with an open
/// `<|turn>model\n` so generation continues as the model.
pub prompt: String,
/// One entry per rendered media block, pushed when its `MEDIA_MARKER` was
/// written, in prompt order. An attachment referenced by several blocks
/// appears once per reference.
pub attachments: Vec<&'a Attachment>,
}
/// Errors returned by `Gemma4Renderer::render`.
///
/// `message_index` / `block_index` locate the offending block as
/// `ResolvedV2::messages[message_index].content[block_index]`, matching the
/// path printed in the `Display` text.
#[derive(Debug, thiserror::Error)]
pub enum Gemma4RenderError {
/// An image, audio, or video block names an attachment id that is not
/// present in `ResolvedV2::attachments`.
#[error(
"messages[{message_index}].content[{block_index}]: attachment {attachment_id:?} not found"
)]
DanglingAttachment {
message_index: usize,
block_index: usize,
/// The id exactly as it appeared in the content block.
attachment_id: String,
},
/// The block is `ContentBlock::Unknown`, which the renderer has no
/// encoding for.
#[error("messages[{message_index}].content[{block_index}] is an unknown content-block type")]
UnknownBlock {
message_index: usize,
block_index: usize,
},
}
/// Stateless renderer that turns a `ResolvedV2` conversation into a Gemma 4
/// prompt string plus its ordered media attachments.
#[derive(Debug, Default)]
pub struct Gemma4Renderer;
impl Gemma4Renderer {
    /// Creates a renderer. It carries no state, so this is the same as
    /// `Gemma4Renderer::default()`.
    pub fn new() -> Self {
        Self
    }

    /// Renders `resolved` into a single Gemma 4 prompt.
    ///
    /// Layout of the returned `prompt`:
    /// 1. `<bos>`;
    /// 2. when `resolved.tools` is non-empty and the conversation does not
    ///    open with a system message, a synthesized `<|turn>system` turn that
    ///    holds only the tool declarations. This also happens when there are
    ///    no messages at all, so declared tools are never silently dropped;
    /// 3. one turn per message, in order. A leading system message is
    ///    rendered by `render_message`, which receives the tools so it can
    ///    declare them inside that turn;
    /// 4. an open `<|turn>model\n` so generation continues as the model.
    ///
    /// Tool responses are labelled with the `name` of the `ToolUse` block that
    /// shares their `tool_call_id`, searched across the whole conversation.
    ///
    /// # Errors
    ///
    /// Returns the first [`Gemma4RenderError`] hit while rendering a message:
    /// a media block naming an attachment id absent from
    /// `resolved.attachments`, or a `ContentBlock::Unknown`.
    pub fn render<'a>(
        &self,
        resolved: &'a ResolvedV2,
    ) -> Result<Gemma4Rendered<'a>, Gemma4RenderError> {
        use std::collections::HashMap;

        let by_id: HashMap<&str, &Attachment> =
            resolved.attachments.iter().map(|a| (a.id(), a)).collect();

        // Index every tool call up front: a ToolResult usually lives in a
        // different message than the ToolUse that produced it.
        let tool_name_by_call_id: HashMap<&ToolCallId, &str> = resolved
            .messages
            .iter()
            .flat_map(|m| m.content.iter())
            .filter_map(|b| match b {
                ContentBlock::ToolUse {
                    tool_call_id, name, ..
                } => Some((tool_call_id, name.as_str())),
                _ => None,
            })
            .collect();

        let mut prompt = String::with_capacity(512);
        let mut attachments: Vec<&Attachment> = Vec::new();

        prompt.push_str("<bos>");

        // Decided before the message loop so an empty conversation still
        // declares its tools (the loop body never runs in that case).
        let opens_with_system =
            matches!(resolved.messages.first(), Some(m) if m.role == RoleV2::System);
        if !resolved.tools.is_empty() && !opens_with_system {
            prompt.push_str("<|turn>system\n");
            render_tool_declarations(&mut prompt, &resolved.tools);
            prompt.push_str("<turn|>\n");
        }

        for (mi, msg) in resolved.messages.iter().enumerate() {
            render_message(
                &mut prompt,
                mi,
                msg,
                &by_id,
                &mut attachments,
                &resolved.tools,
                &tool_name_by_call_id,
            )?;
        }

        prompt.push_str("<|turn>model\n");
        Ok(Gemma4Rendered {
            prompt,
            attachments,
        })
    }
}
/// Renders one message as a complete turn (`<|turn>{role}\n … <turn|>\n`) and
/// appends it to `out`.
///
/// Content blocks are emitted in order, with no separators between them:
/// - `Text` is copied verbatim;
/// - `Image` / `Audio` / `Video` write [`MEDIA_MARKER`] and push the referenced
///   attachment onto `attachments`;
/// - `ToolUse` becomes `<|tool_call>call:{name}{args}<tool_call|>`;
/// - `ToolResult` becomes `<|tool_response>…<tool_response|>`. The body is
///   wrapped in `response:{name}{…}` when the tool name is known, either from
///   `tool_name_by_call_id` or because `tools` declares exactly one tool.
///
/// If this is the first message of the conversation (`mi == 0`), it is a
/// system message, and `tools` is non-empty, the tool declarations are
/// appended at the end of this turn.
///
/// # Errors
///
/// - [`Gemma4RenderError::DanglingAttachment`] when a media block's id is not
///   in `by_id`.
/// - [`Gemma4RenderError::UnknownBlock`] for `ContentBlock::Unknown`.
///
/// On error, `out` and `attachments` may already hold partial output for this
/// message.
fn render_message<'a>(
    out: &mut String,
    mi: usize,
    msg: &'a MessageV2,
    by_id: &std::collections::HashMap<&str, &'a Attachment>,
    attachments: &mut Vec<&'a Attachment>,
    tools: &[Tool],
    tool_name_by_call_id: &std::collections::HashMap<&'a ToolCallId, &'a str>,
) -> Result<(), Gemma4RenderError> {
    out.push_str(role_open_tag(msg.role));
    out.push('\n');

    for (bi, block) in msg.content.iter().enumerate() {
        match block {
            ContentBlock::Text { text } => out.push_str(text),
            ContentBlock::Image { attachment_id }
            | ContentBlock::Audio { attachment_id }
            | ContentBlock::Video { attachment_id } => {
                let att = by_id.get(attachment_id.as_str()).ok_or_else(|| {
                    Gemma4RenderError::DanglingAttachment {
                        message_index: mi,
                        block_index: bi,
                        attachment_id: attachment_id.clone(),
                    }
                })?;
                // Marker and attachment are recorded together so they stay
                // in the same order.
                out.push_str(MEDIA_MARKER);
                attachments.push(*att);
            }
            ContentBlock::ToolUse { name, input, .. } => {
                out.push_str("<|tool_call>call:");
                out.push_str(name);
                out.push('{');
                render_args_inline(out, input);
                out.push_str("}<tool_call|>");
            }
            ContentBlock::ToolResult {
                tool_call_id,
                content,
            } => {
                out.push_str("<|tool_response>");
                let tool_name = tool_name_by_call_id
                    .get(tool_call_id)
                    .copied()
                    .or_else(|| guess_tool_name_from_tools(tools));
                if let Some(name) = tool_name {
                    out.push_str("response:");
                    out.push_str(name);
                    out.push('{');
                    render_text_only_response(out, content);
                    out.push('}');
                } else {
                    render_text_only_response(out, content);
                }
                out.push_str("<tool_response|>");
            }
            ContentBlock::Unknown => {
                return Err(Gemma4RenderError::UnknownBlock {
                    message_index: mi,
                    block_index: bi,
                });
            }
        }
    }

    // Only the conversation's first message may host the declarations. When
    // it is not a system message, `Gemma4Renderer::render` synthesizes its own
    // declaration turn. Without the `mi == 0` guard, any later system message
    // would declare every tool a second time.
    if mi == 0 && msg.role == RoleV2::System && !tools.is_empty() {
        render_tool_declarations(out, tools);
    }

    out.push_str("<turn|>\n");
    Ok(())
}
/// Opening tag (without the trailing newline) for a turn spoken by `role`.
///
/// The assistant side of the conversation is tagged `model`, not `assistant`.
fn role_open_tag(role: RoleV2) -> &'static str {
    match role {
        RoleV2::User => "<|turn>user",
        RoleV2::Assistant => "<|turn>model",
        RoleV2::System => "<|turn>system",
    }
}
/// Appends one declaration per tool, in slice order, each shaped as
/// `<|tool>declaration:{name}{description:<|"|>{description}<|"|>,parameters:{schema}}<tool|>`.
///
/// `{schema}` is the tool's `input_schema` rendered by [`render_schema`].
/// An empty `tools` slice appends nothing.
fn render_tool_declarations(out: &mut String, tools: &[Tool]) {
    tools.iter().for_each(|tool| {
        out.push_str("<|tool>declaration:");
        out.push_str(&tool.name);
        out.push_str("{description:<|\"|>");
        out.push_str(&tool.description);
        out.push_str("<|\"|>,parameters:");
        render_schema(out, &tool.input_schema);
        out.push_str("}<tool|>");
    });
}
/// Serializes a JSON value in the prompt's compact notation and appends it to
/// `out`.
///
/// This differs from JSON in two ways: strings are delimited by `<|"|>` on
/// both sides (their contents are not escaped), and object keys are written
/// bare. Arrays and objects use `[…]` / `{…}` with `,` separators.
/// `null`, booleans, and numbers print as in JSON. Entries are emitted in the
/// order the container iterates them.
fn render_schema(out: &mut String, value: &Value) {
    const QUOTE: &str = "<|\"|>";
    match value {
        Value::Null => out.push_str("null"),
        Value::Bool(true) => out.push_str("true"),
        Value::Bool(false) => out.push_str("false"),
        Value::Number(num) => out.push_str(&num.to_string()),
        Value::String(text) => {
            out.push_str(QUOTE);
            out.push_str(text);
            out.push_str(QUOTE);
        }
        Value::Array(elems) => {
            out.push('[');
            let mut sep = "";
            for elem in elems {
                out.push_str(sep);
                render_schema(out, elem);
                sep = ",";
            }
            out.push(']');
        }
        Value::Object(fields) => {
            out.push('{');
            let mut sep = "";
            for (key, field) in fields {
                out.push_str(sep);
                out.push_str(key);
                out.push(':');
                render_schema(out, field);
                sep = ",";
            }
            out.push('}');
        }
    }
}
/// Renders tool-call arguments for the body of `call:{name}{…}`.
///
/// An object becomes bare `key:value` pairs joined by `,`; the caller supplies
/// the surrounding braces. Values are rendered by [`render_schema`]. Any
/// non-object value is passed to [`render_schema`] whole.
fn render_args_inline(out: &mut String, value: &Value) {
    match value {
        Value::Object(fields) => {
            for (idx, (key, field)) in fields.iter().enumerate() {
                if idx != 0 {
                    out.push(',');
                }
                out.push_str(key);
                out.push(':');
                render_schema(out, field);
            }
        }
        other => render_schema(out, other),
    }
}
/// Fallback tool name for a tool response whose call id has no matching
/// `ToolUse`.
///
/// The answer is only unambiguous when exactly one tool is declared, so that
/// tool's name is returned. With zero or several tools the result is `None`.
fn guess_tool_name_from_tools(tools: &[Tool]) -> Option<&str> {
    match tools {
        [only] => Some(only.name.as_str()),
        _ => None,
    }
}
fn render_text_only_response(out: &mut String, content: &[ContentBlock]) {
for block in content {
if let ContentBlock::Text { text } = block {
if let Ok(Value::Object(map)) = serde_json::from_str::<Value>(text) {
let mut first = true;
for (k, v) in map {
if !first {
out.push(',');
}
first = false;
out.push_str(&k);
out.push(':');
render_schema(out, &v);
}
} else {
out.push_str(text);
}
}
}
}