1use refrain_core::{Op, Pattern, Refrain};
8
9use crate::{AdapterCaps, AdapterErr, EmitCtx, ExtractedRefrain, RefrainAdapter};
10
/// Host language that a [`CodeAdapter`] renders a refrain into.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CodeLang {
    /// A Python module that passes the refrain source to `refrain_py`'s `_native.parse_refrain`.
    Python,
    /// A Rust function that rebuilds the refrain with `refrain_core::parse`.
    Rust,
}

/// Adapter that emits a refrain as source code in a host language.
///
/// The generated program embeds the refrain's s-expression text and parses it at runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CodeAdapter {
    /// Language the emitted source is written in.
    pub lang: CodeLang,
}
20
21impl CodeAdapter {
22 pub fn new(lang: CodeLang) -> Self {
23 Self { lang }
24 }
25
26 fn emit_python(&self, r: &Refrain) -> String {
27 let mut out = String::new();
28 out.push_str("# Auto-generated by refrain-adapters code emitter.\n");
29 out.push_str("from refrain_py import _native\n\n");
30 out.push_str("REFRAIN_SRC = (\n");
31 out.push_str(" \"(refrain ");
32 out.push_str(&r.name);
33 for (kind, p) in r.stages() {
34 out.push_str(" (");
35 out.push_str(kind.as_str());
36 out.push(' ');
37 out.push_str(&pattern_to_sexp(p));
38 out.push(')');
39 }
40 out.push_str(")\"\n)\n\n");
41 out.push_str("refrain_json = _native.parse_refrain(REFRAIN_SRC)\n");
42 out
43 }
44
45 fn emit_rust(&self, r: &Refrain) -> String {
46 let mut out = String::new();
47 out.push_str("// Auto-generated by refrain-adapters code emitter.\n");
48 out.push_str("use refrain_core::parse;\n\n");
49 out.push_str("pub fn refrain() -> refrain_core::Refrain {\n");
50 out.push_str(" parse(r#\"(refrain ");
51 out.push_str(&r.name);
52 for (kind, p) in r.stages() {
53 out.push_str(" (");
54 out.push_str(kind.as_str());
55 out.push(' ');
56 out.push_str(&pattern_to_sexp(p));
57 out.push(')');
58 }
59 out.push_str(")\"#).expect(\"valid refrain\")\n}\n");
60 out
61 }
62}
63
64fn pattern_to_sexp(p: &Pattern) -> String {
65 match p {
66 Pattern::Op(op) => op_to_sexp(op),
67 Pattern::Seq(items) => items
68 .iter()
69 .map(pattern_to_sexp)
70 .collect::<Vec<_>>()
71 .join(" "),
72 }
73}
74
75fn op_to_sexp(op: &Op) -> String {
76 match op {
77 Op::Note { pitch, dur } => format!("(note {} {})", pitch, dur),
78 Op::Loop { count, body } => format!("(loop {} {})", count, pattern_to_sexp(body)),
79 Op::Diff { x, t } => format!("(dy/dx {} {})", x, t),
80 Op::Quotient { rels } => format!("(quotient {})", rels.join(" ")),
81 Op::Sym(s) => s.clone(),
82 Op::Call { head, args } => {
83 let mut s = String::from("(");
84 s.push_str(head);
85 for a in args {
86 s.push(' ');
87 s.push_str(&pattern_to_sexp(a));
88 }
89 s.push(')');
90 s
91 }
92 }
93}
94
95impl RefrainAdapter for CodeAdapter {
96 fn name(&self) -> &str {
97 match self.lang {
98 CodeLang::Python => "code.python",
99 CodeLang::Rust => "code.rust",
100 }
101 }
102
103 fn emit(&self, refrain: &ExtractedRefrain, _ctx: &EmitCtx) -> Result<Vec<u8>, AdapterErr> {
104 let s = match self.lang {
105 CodeLang::Python => self.emit_python(refrain.refrain),
106 CodeLang::Rust => self.emit_rust(refrain.refrain),
107 };
108 Ok(s.into_bytes())
109 }
110
111 fn capabilities(&self) -> AdapterCaps {
112 AdapterCaps {
113 realtime: false,
114 differentiable: false,
115 }
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use refrain_core::parse;

    /// Runs the `lang` adapter over `r` and decodes the emitted bytes as UTF-8.
    fn emit_text(lang: CodeLang, r: &Refrain) -> String {
        let extracted = ExtractedRefrain { refrain: r };
        let bytes = CodeAdapter::new(lang)
            .emit(&extracted, &EmitCtx::default())
            .unwrap();
        String::from_utf8(bytes).unwrap()
    }

    #[test]
    fn python_roundtrip_via_sexp() {
        let r = parse("(refrain mel (territorialize (loop 4 (note C4 q))))").unwrap();
        let text = emit_text(CodeLang::Python, &r);
        assert!(text.contains("from refrain_py import _native"));
        assert!(text.contains("(refrain mel"));
        assert!(text.contains("(territorialize (loop 4 (note C4 q)))"));
    }

    #[test]
    fn rust_roundtrip_via_sexp() {
        let r = parse("(refrain m (deterritorialize (dy/dx x t)))").unwrap();
        let text = emit_text(CodeLang::Rust, &r);
        assert!(text.contains("use refrain_core::parse"));
        assert!(text.contains("(deterritorialize (dy/dx x t))"));
    }

    #[test]
    fn emitted_python_is_parseable_back_to_same_refrain() {
        let src = "(refrain n \
                   (territorialize (loop 4 (note C4 q))) \
                   (deterritorialize (dy/dx i t)) \
                   (reterritorialize (quotient ~r ~s)))";
        let original = parse(src).unwrap();
        let emitted = emit_text(CodeLang::Python, &original);

        // Slice out the text between the opening quote and the closing `)"` of the
        // embedded literal, keeping the final `)` of the s-expression.
        let open = emitted.find("\"(refrain").expect("refrain literal found") + 1;
        let close = open + emitted[open..].find(")\"\n").expect("closing quote") + 1;

        let reparsed = parse(&emitted[open..close]).unwrap();
        assert_eq!(reparsed, original);
    }

    #[test]
    fn empty_refrain_emits_minimal_python() {
        let r = parse("(refrain x)").unwrap();
        let text = emit_text(CodeLang::Python, &r);
        assert!(text.contains("(refrain x)"));
    }

    #[test]
    fn names_distinguish_languages() {
        assert_eq!(CodeAdapter::new(CodeLang::Python).name(), "code.python");
        assert_eq!(CodeAdapter::new(CodeLang::Rust).name(), "code.rust");
    }
}