1use std::collections::{HashMap, VecDeque};
2use std::fs;
3use std::io::prelude::*;
4
5use apache_avro::schema::{ArraySchema, DecimalSchema, MapSchema, Name, RecordField, RecordSchema};
6
7use crate::Schema;
8use crate::error::{Error, Result};
9use crate::templates::*;
10
11pub enum Source<'a> {
13 Schema(&'a Schema),
15 Schemas(&'a [Schema]),
17 SchemaStr(&'a str),
19 GlobPattern(&'a str),
21}
22
23#[derive(Debug)]
27pub struct Generator {
28 templater: Templater,
29}
30
31impl Generator {
32 pub fn new() -> Result<Generator> {
34 GeneratorBuilder::new().build()
35 }
36
37 pub fn builder() -> GeneratorBuilder {
39 GeneratorBuilder::new()
40 }
41
42 pub fn generate(&self, source: &Source, output: &mut impl Write) -> Result<()> {
45 match source {
46 Source::Schema(schema) => {
47 let mut deps = deps_stack(schema, vec![]);
48 self.gen_in_order(&mut deps, output)?;
49 }
50
51 Source::Schemas(schemas) => {
52 let mut deps = schemas
53 .iter()
54 .fold(vec![], |deps, schema| deps_stack(schema, deps));
55
56 self.gen_in_order(&mut deps, output)?;
57 }
58
59 Source::SchemaStr(raw_schema) => {
60 let schema = Schema::parse_str(raw_schema)?;
61 let mut deps = deps_stack(&schema, vec![]);
62 self.gen_in_order(&mut deps, output)?;
63 }
64
65 Source::GlobPattern(pattern) => {
66 let mut raw_schemas = vec![];
67 let mut paths = glob::glob(pattern)?.peekable();
68 if paths.peek().is_none() {
69 return Err(Error::GlobPattern(glob::PatternError {
70 pos: 0,
71 msg: "No files with the given glob pattern were found",
72 }));
73 }
74 for path in paths {
75 let path = path.map_err(|e| e.into_error())?;
76 if !path.is_dir() {
77 raw_schemas.push(fs::read_to_string(path)?);
78 }
79 }
80
81 let schemas = &raw_schemas.iter().map(|s| s.as_str()).collect::<Vec<_>>();
82 let schemas = Schema::parse_list(schemas)?;
83 self.generate(&Source::Schemas(&schemas), output)?;
84 }
85 }
86
87 Ok(())
88 }
89
90 fn gen_in_order(&self, deps: &mut Vec<Schema>, output: &mut impl Write) -> Result<()> {
96 let mut gs = GenState::new(deps)?.with_chrono_dates(self.templater.use_chrono_dates);
97
98 if !self.templater.field_overrides.is_empty() {
99 gs = gs.with_field_overrides(deps, &self.templater.field_overrides)?;
101 }
102
103 while let Some(s) = deps.pop() {
104 match s {
105 Schema::Fixed { .. } => {
107 let code = &self.templater.str_fixed(&s)?;
108 output.write_all(code.as_bytes())?
109 }
110 Schema::Enum { .. } => {
111 let code = &self.templater.str_enum(&s)?;
112 output.write_all(code.as_bytes())?
113 }
114
115 Schema::Record { .. } => {
117 let code = &self.templater.str_record(&s, &gs)?;
118 output.write_all(code.as_bytes())?
119 }
120
121 Schema::Array(ArraySchema {
123 items: ref inner, ..
124 }) => {
125 let type_str = array_type(inner, &gs)?;
126 gs.put_type(&s, type_str)
127 }
128 Schema::Map(MapSchema {
129 types: ref inner, ..
130 }) => {
131 let type_str = map_type(inner, &gs)?;
132 gs.put_type(&s, type_str)
133 }
134
135 Schema::Union(ref union) => {
136 if (union.is_nullable() && union.variants().len() > 2)
138 || (!union.is_nullable() && !union.variants().is_empty())
139 {
140 let code = &self.templater.str_union_enum(&s, &gs)?;
141 output.write_all(code.as_bytes())?
142 }
143
144 let type_str = union_type(union, &gs, true)?;
146 gs.put_type(&s, type_str)
147 }
148
149 _ => return Err(Error::Schema(format!("Not a valid root schema: {s:?}"))),
150 }
151 }
152
153 Ok(())
154 }
155}
156
157fn deps_stack(schema: &Schema, mut deps: Vec<Schema>) -> Vec<Schema> {
162 fn push_unique(deps: &mut Vec<Schema>, s: Schema) {
163 let existing = deps.iter().position(|d| match (d, &s) {
170 (Schema::Record(r1), Schema::Record(r2)) => r1.name == r2.name,
171 (Schema::Enum(e1), Schema::Enum(e2)) => e1.name == e2.name,
172 (Schema::Fixed(f1), Schema::Fixed(f2)) => f1.name == f2.name,
173 _ => d == &s,
174 });
175
176 if let Some(i) = existing {
177 deps.remove(i);
178 }
179 deps.push(s);
180 }
181
182 let mut q = VecDeque::new();
183
184 q.push_back(schema);
185 while !q.is_empty() {
186 let s = q.pop_front().unwrap();
187
188 match s {
189 Schema::Enum { .. } => push_unique(&mut deps, s.clone()),
191 Schema::Fixed { .. } => push_unique(&mut deps, s.clone()),
192 Schema::Decimal(DecimalSchema { inner, .. })
193 if matches!(inner.as_ref(), Schema::Fixed { .. }) =>
194 {
195 push_unique(&mut deps, s.clone())
196 }
197
198 Schema::Record(RecordSchema { fields, .. }) => {
200 push_unique(&mut deps, s.clone());
201
202 let by_pos = fields
203 .iter()
204 .map(|f| (f.position, f))
205 .collect::<HashMap<_, _>>();
206 let mut i = 0;
207 while let Some(RecordField { schema: sr, .. }) = by_pos.get(&i) {
208 match sr {
209 Schema::Fixed { .. } => push_unique(&mut deps, sr.clone()),
211 Schema::Enum { .. } => push_unique(&mut deps, sr.clone()),
212
213 Schema::Record { .. } => q.push_back(sr),
215
216 Schema::Map(MapSchema { types: sc, .. })
218 | Schema::Array(ArraySchema { items: sc, .. }) => match sc.as_ref() {
219 Schema::Fixed { .. }
220 | Schema::Enum { .. }
221 | Schema::Record { .. }
222 | Schema::Map(..)
223 | Schema::Array(..)
224 | Schema::Union(..) => {
225 q.push_back(sc);
226 push_unique(&mut deps, s.clone());
227 }
228 _ => (),
229 },
230 Schema::Union(union) => {
231 if (union.is_nullable() && union.variants().len() > 2)
232 || (!union.is_nullable() && !union.variants().is_empty())
233 {
234 push_unique(&mut deps, sr.clone());
235 }
236
237 union.variants().iter().for_each(|sc| match sc {
238 Schema::Fixed { .. }
239 | Schema::Enum { .. }
240 | Schema::Record { .. }
241 | Schema::Map(..)
242 | Schema::Array(..)
243 | Schema::Union(..) => {
244 q.push_back(sc);
245 push_unique(&mut deps, sc.clone());
246 }
247
248 _ => (),
249 });
250 }
251 _ => (),
252 };
253 i += 1;
254 }
255 }
256
257 Schema::Map(MapSchema { types: sc, .. })
259 | Schema::Array(ArraySchema { items: sc, .. }) => match sc.as_ref() {
260 Schema::Fixed { .. }
262 | Schema::Enum { .. }
263 | Schema::Record { .. }
264 | Schema::Map(..)
265 | Schema::Array(..)
266 | Schema::Union(..) => {
267 q.push_back(sc.as_ref());
268 push_unique(&mut deps, s.clone());
269 }
270 _ => push_unique(&mut deps, s.clone()),
272 },
273
274 Schema::Union(union) => {
275 if (union.is_nullable() && union.variants().len() > 2)
276 || (!union.is_nullable() && union.variants().len() > 1)
277 {
278 push_unique(&mut deps, s.clone());
279 }
280
281 union.variants().iter().for_each(|sc| match sc {
282 Schema::Fixed { .. }
284 | Schema::Enum { .. }
285 | Schema::Record { .. }
286 | Schema::Map(..)
287 | Schema::Array(..)
288 | Schema::Union(..) => {
289 q.push_back(sc);
290 push_unique(&mut deps, s.clone());
291 }
292 _ => push_unique(&mut deps, s.clone()),
294 });
295 }
296
297 _ => (),
299 }
300 }
301
302 deps
303}
304
305pub struct GeneratorBuilder {
307 precision: usize,
308 nullable: bool,
309 use_avro_rs_unions: bool,
310 use_chrono_dates: bool,
311 derive_builders: bool,
312 impl_schemas: ImplementAvroSchema,
313 extra_derives: Vec<String>,
314 field_overrides: HashMap<Name, Vec<FieldOverride>>,
315}
316
317impl Default for GeneratorBuilder {
318 fn default() -> Self {
319 Self {
320 precision: 3,
321 nullable: false,
322 use_avro_rs_unions: false,
323 use_chrono_dates: false,
324 derive_builders: false,
325 impl_schemas: ImplementAvroSchema::None,
326 extra_derives: vec![],
327 field_overrides: HashMap::new(),
328 }
329 }
330}
331
/// Selects how (and whether) the generated types get an Avro schema
/// implementation. `GeneratorBuilder::build` translates the chosen variant
/// into two templater flags.
#[derive(PartialEq, Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "build-cli", derive(clap::ValueEnum))]
pub enum ImplementAvroSchema {
    /// Turns on the templater's `derive_schemas` flag.
    Derive,

    /// Turns on the templater's `impl_schemas` flag.
    CopyBuildSchema,

    /// Leaves both flags off. This is the default.
    #[default]
    None,
}
361
362impl std::fmt::Display for ImplementAvroSchema {
363 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
364 write!(f, "{self:?}")
365 }
366}
367
368#[derive(Debug, Clone)]
376
377pub struct FieldOverride {
378 pub schema: Name,
380 pub field: String,
382 pub docstring: Option<String>,
384 pub type_name: Option<String>,
390 pub implements_eq: Option<bool>,
394 pub serde_with: Option<String>,
398 pub default: Option<String>,
402}
403
404impl GeneratorBuilder {
405 pub fn new() -> GeneratorBuilder {
407 GeneratorBuilder::default()
408 }
409
410 pub fn precision(mut self, precision: usize) -> GeneratorBuilder {
412 self.precision = precision;
413 self
414 }
415
416 pub fn nullable(mut self, nullable: bool) -> GeneratorBuilder {
420 self.nullable = nullable;
421 self
422 }
423
424 pub fn use_avro_rs_unions(mut self, use_avro_rs_unions: bool) -> GeneratorBuilder {
429 self.use_avro_rs_unions = use_avro_rs_unions;
430 self
431 }
432
433 pub fn use_chrono_dates(mut self, use_chrono_dates: bool) -> GeneratorBuilder {
435 self.use_chrono_dates = use_chrono_dates;
436 self
437 }
438
439 pub fn derive_builders(mut self, derive_builders: bool) -> GeneratorBuilder {
443 self.derive_builders = derive_builders;
444 self
445 }
446
447 pub fn implement_avro_schema(mut self, impl_schemas: ImplementAvroSchema) -> GeneratorBuilder {
456 self.impl_schemas = impl_schemas;
457 self
458 }
459
460 pub fn extra_derives(mut self, extra_derives: Vec<String>) -> GeneratorBuilder {
464 self.extra_derives = extra_derives;
465 self
466 }
467
468 pub fn override_fields(mut self, overrides: Vec<FieldOverride>) -> GeneratorBuilder {
472 for over in overrides {
473 self.field_overrides
474 .entry(over.schema.clone())
475 .or_default()
476 .push(over);
477 }
478 self
479 }
480
481 pub fn override_field(mut self, over: FieldOverride) -> GeneratorBuilder {
485 self.field_overrides
486 .entry(over.schema.clone())
487 .or_default()
488 .push(over);
489 self
490 }
491
492 pub fn build(self) -> Result<Generator> {
494 let mut templater = Templater::new()?;
495 templater.precision = self.precision;
496 templater.nullable = self.nullable;
497 templater.use_avro_rs_unions = self.use_avro_rs_unions;
498 templater.use_chrono_dates = self.use_chrono_dates;
499 templater.derive_builders = self.derive_builders;
500 templater.derive_schemas = self.impl_schemas == ImplementAvroSchema::Derive;
501 templater.impl_schemas = self.impl_schemas == ImplementAvroSchema::CopyBuildSchema;
502 templater.extra_derives = self.extra_derives;
503 templater.field_overrides = self.field_overrides;
504 Ok(Generator { templater })
505 }
506}
507
#[cfg(test)]
mod tests {
    use apache_avro::schema::{EnumSchema, Name};
    use pretty_assertions::assert_eq;

    use super::*;

    /// A nested record's dependencies come off the stack innermost first.
    #[test]
    fn deps() {
        let raw_schema = r#"
{
  "type": "record",
  "name": "User",
  "fields": [
    {"name": "name", "type": "string", "default": "unknown"},
    {"name": "address",
     "type": {
       "type": "record",
       "name": "Address",
       "fields": [
         {"name": "city", "type": "string", "default": "unknown"},
         {"name": "country",
          "type": {"type": "enum", "name": "Country", "symbols": ["FR", "JP"]}
         }
       ]
     }
    }
  ]
}
"#;

        let parsed = Schema::parse_str(raw_schema).unwrap();
        let mut stack = deps_stack(&parsed, vec![]);

        match stack.pop() {
            Some(Schema::Enum(EnumSchema {
                name: Name { name, .. },
                ..
            })) => assert_eq!(name, "Country"),
            other => panic!("expected enum `Country`, got {other:?}"),
        }

        match stack.pop() {
            Some(Schema::Record(RecordSchema {
                name: Name { name, .. },
                ..
            })) => assert_eq!(name, "Address"),
            other => panic!("expected record `Address`, got {other:?}"),
        }

        match stack.pop() {
            Some(Schema::Record(RecordSchema {
                name: Name { name, .. },
                ..
            })) => assert_eq!(name, "User"),
            other => panic!("expected record `User`, got {other:?}"),
        }

        assert!(stack.pop().is_none());
    }

    /// Schemas split over several files may reference one another when they
    /// are loaded through a glob pattern.
    #[test]
    fn cross_deps() -> std::result::Result<(), Box<dyn std::error::Error>> {
        use std::fs::File;
        use std::io::Write;
        use tempfile::tempdir;

        let schema_a_str = r#"
{
  "name": "A",
  "type": "record",
  "fields": [ {"name": "field_one", "type": "float"} ]
}
"#;
        let schema_b_str = r#"
{
  "name": "B",
  "type": "record",
  "fields": [ {"name": "field_one", "type": "A"} ]
}
"#;

        let dir = tempdir()?;
        for (file_name, contents) in [
            ("schema_a.avsc", schema_a_str),
            ("schema_b.avsc", schema_b_str),
        ] {
            File::create(dir.path().join(file_name))?.write_all(contents.as_bytes())?;
        }

        let expected = r#"
#[derive(Debug, PartialEq, Clone, serde::Deserialize, serde::Serialize)]
pub struct B {
    pub field_one: A,
}

#[derive(Debug, PartialEq, Clone, serde::Deserialize, serde::Serialize)]
pub struct A {
    pub field_one: f32,
}
"#;

        let pattern = format!("{}/*.avsc", dir.path().display());
        let generator = Generator::new()?;
        let mut generated = vec![];
        generator.generate(&Source::GlobPattern(pattern.as_str()), &mut generated)?;

        assert_eq!(expected, String::from_utf8(generated)?);

        dir.close()?;
        Ok(())
    }
}
615}