use super::error::CompilationError;
/// Builder for the source files of a Hyperlight guest crate that runs an
/// agent-generated Rust snippet.
///
/// The only setting is whether the generated crate carries
/// `#![forbid(unsafe_code)]`; it is enabled by default.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CodeTemplate {
    /// When `true`, the generated source starts with `#![forbid(unsafe_code)]`.
    forbid_unsafe: bool,
}

impl Default for CodeTemplate {
    /// Returns a template that forbids `unsafe` code in the generated guest.
    fn default() -> Self {
        let forbid_unsafe = true;
        Self { forbid_unsafe }
    }
}
impl CodeTemplate {
    /// Creates a template with default settings; equivalent to
    /// [`CodeTemplate::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the template with the `#![forbid(unsafe_code)]` crate attribute
    /// turned on (`true`) or off (`false`).
    #[must_use]
    pub fn with_forbid_unsafe(self, forbid: bool) -> Self {
        let mut updated = self;
        updated.forbid_unsafe = forbid;
        updated
    }

    /// Renders a complete `no_std` / `no_main` guest source file with `code`
    /// spliced in verbatim as the body of the `run_code` guest function.
    ///
    /// The snippet forms the entire body of a function that returns `String`,
    /// with `input: String` in scope. The generated file also declares a
    /// `host_log` host function and imports `alloc`'s `String`, `format!` and
    /// `Vec`.
    ///
    /// # Errors
    ///
    /// Returns the error built by `CompilationError::template_failed` when
    /// `code` is empty or contains only whitespace.
    pub fn wrap(&self, code: &str) -> Result<String, CompilationError> {
        // Same test as `code.trim().is_empty()`: both use Unicode White_Space.
        let is_blank = code.chars().all(char::is_whitespace);
        if is_blank {
            return Err(CompilationError::template_failed("code cannot be empty"));
        }

        // A crate-level `forbid` cannot be downgraded by an `allow` written
        // inside the snippet.
        let lint_attr = if self.forbid_unsafe {
            "#![forbid(unsafe_code)]"
        } else {
            ""
        };

        Ok(format!(
            r#"#![no_std]
#![no_main]
{lint_attr}
extern crate alloc;
use alloc::string::String;
use alloc::format;
use alloc::vec::Vec;
use hyperlight_guest_bin::{{guest_function, host_function}};
/// Host function for logging (optional).
#[host_function]
fn host_log(msg: String) -> i32;
/// The code execution function called by the host.
///
/// Takes an input string and returns the result as a string.
#[guest_function("run_code")]
fn run_code(input: String) -> String {{
// Agent-generated code begins here
{code}
// Agent-generated code ends here
}}
"#
        ))
    }

    /// Returns the `Cargo.toml` manifest for the generated guest crate.
    ///
    /// The manifest builds a `staticlib` named `rust_code_guest` that depends
    /// on `hyperlight-guest-bin` 0.12, with a size-optimised, LTO-enabled,
    /// abort-on-panic release profile. It does not depend on this template's
    /// settings.
    #[must_use]
    pub fn cargo_toml(&self) -> String {
        const MANIFEST: &str = r#"[package]
name = "rust_code_guest"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["staticlib"]
[dependencies]
hyperlight-guest-bin = "0.12"
[profile.release]
panic = "abort"
lto = true
opt-level = "s"
"#;
        String::from(MANIFEST)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::compiler::CompilationErrorKind;

    /// A minimal snippet that is a valid body for `run_code`.
    const SIMPLE_SNIPPET: &str = "input.to_string()";

    /// Wraps [`SIMPLE_SNIPPET`] using a default template.
    fn wrap_default() -> String {
        CodeTemplate::default()
            .wrap(SIMPLE_SNIPPET)
            .expect("non-empty snippet must wrap successfully")
    }

    /// Returns the manifest produced by a default template.
    fn default_manifest() -> String {
        CodeTemplate::default().cargo_toml()
    }

    /// Asserts that wrapping `code` fails with a `TemplateFailed` error.
    fn assert_template_failed(code: &str) {
        let err = CodeTemplate::default()
            .wrap(code)
            .expect_err("blank snippet must be rejected");
        assert!(matches!(
            err.kind(),
            CompilationErrorKind::TemplateFailed { .. }
        ));
    }

    #[test]
    fn template_default_forbids_unsafe() {
        assert!(CodeTemplate::default().forbid_unsafe);
    }

    #[test]
    fn template_new_equals_default() {
        assert_eq!(
            CodeTemplate::new().forbid_unsafe,
            CodeTemplate::default().forbid_unsafe
        );
    }

    #[test]
    fn template_with_forbid_unsafe_false() {
        assert!(!CodeTemplate::new().with_forbid_unsafe(false).forbid_unsafe);
    }

    #[test]
    fn template_with_forbid_unsafe_round_trip() {
        let template = CodeTemplate::new()
            .with_forbid_unsafe(false)
            .with_forbid_unsafe(true);
        assert!(template.forbid_unsafe);
    }

    #[test]
    fn template_wrap_includes_no_std() {
        assert!(wrap_default().contains("#![no_std]"));
    }

    #[test]
    fn template_wrap_includes_no_main() {
        assert!(wrap_default().contains("#![no_main]"));
    }

    #[test]
    fn template_wrap_includes_forbid_unsafe() {
        assert!(wrap_default().contains("#![forbid(unsafe_code)]"));
    }

    #[test]
    fn template_wrap_without_forbid_unsafe() {
        let wrapped = CodeTemplate::new()
            .with_forbid_unsafe(false)
            .wrap(SIMPLE_SNIPPET)
            .expect("non-empty snippet must wrap successfully");
        assert!(!wrapped.contains("#![forbid(unsafe_code)]"));
    }

    #[test]
    fn template_wrap_includes_alloc() {
        assert!(wrap_default().contains("extern crate alloc"));
    }

    #[test]
    fn template_wrap_includes_guest_function() {
        assert!(wrap_default().contains("#[guest_function(\"run_code\")]"));
    }

    #[test]
    fn template_wrap_includes_host_function() {
        let wrapped = wrap_default();
        assert!(wrapped.contains("#[host_function]"));
        assert!(wrapped.contains("fn host_log"));
    }

    #[test]
    fn template_wrap_includes_user_code() {
        let code = "input.chars().rev().collect()";
        let wrapped = CodeTemplate::default()
            .wrap(code)
            .expect("non-empty snippet must wrap successfully");
        assert!(wrapped.contains(code));
    }

    #[test]
    fn template_wrap_places_code_between_markers() {
        let wrapped = wrap_default();
        let begin = wrapped
            .find("// Agent-generated code begins here")
            .expect("begin marker present");
        let end = wrapped
            .find("// Agent-generated code ends here")
            .expect("end marker present");
        let snippet = wrapped.find(SIMPLE_SNIPPET).expect("snippet present");
        assert!(begin < snippet && snippet < end);
    }

    #[test]
    fn template_wrap_empty_code_fails() {
        assert_template_failed("");
    }

    #[test]
    fn template_wrap_whitespace_only_fails() {
        // Must fail the same way as the empty string, not just "some error".
        assert_template_failed(" \n\t ");
    }

    #[test]
    fn template_wrap_multiline_code() {
        let code = r#"
let parts: Vec<&str> = input.split(',').collect();
parts.join("-")
"#;
        let wrapped = CodeTemplate::default()
            .wrap(code)
            .expect("multi-line snippet must wrap successfully");
        assert!(wrapped.contains("split"));
        assert!(wrapped.contains(code));
    }

    #[test]
    fn cargo_toml_has_package() {
        let toml = default_manifest();
        assert!(toml.contains("[package]"));
        assert!(toml.contains("name = \"rust_code_guest\""));
    }

    #[test]
    fn cargo_toml_has_lib_staticlib() {
        let toml = default_manifest();
        assert!(toml.contains("[lib]"));
        assert!(toml.contains("crate-type = [\"staticlib\"]"));
    }

    #[test]
    fn cargo_toml_has_hyperlight_dependency() {
        let toml = default_manifest();
        assert!(toml.contains("[dependencies]"));
        assert!(toml.contains("hyperlight-guest-bin"));
    }

    #[test]
    fn cargo_toml_has_release_profile() {
        let toml = default_manifest();
        assert!(toml.contains("[profile.release]"));
        assert!(toml.contains("panic = \"abort\""));
        assert!(toml.contains("lto = true"));
    }

    #[test]
    fn template_is_clone() {
        let original = CodeTemplate::new().with_forbid_unsafe(false);
        let copy = original.clone();
        assert_eq!(original.forbid_unsafe, copy.forbid_unsafe);
    }
}