bicycle/
utils.rs

/*
BicycleDB is a protobuf-defined database management system.

Copyright (C) 2024 Ordinary Labs

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/
19
20use prost_types::{
21    field_descriptor_proto::{self, Type},
22    DescriptorProto, FieldDescriptorProto,
23};
24
/// A single field of a protobuf message, as collected by `construct_model`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Property {
    /// The field's type rendered in proto syntax (for example `string` or
    /// `repeated int32`).
    pub _type: String,
    /// The field name as declared in the `.proto` source.
    pub name: String,
    /// The field number (tag) as declared in the `.proto` source.
    pub number: i32,
}

/// A protobuf message flattened into its fields and its nested message types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Model {
    /// The message's short name.
    pub name: String,
    /// One entry per field of the message, in declaration order.
    pub properties: Vec<Property>,
    /// Models built from the message types nested inside this message.
    pub nested_models: Vec<Model>,
}
38
39pub fn construct_model(
40    message: &DescriptorProto,
41    should_check_pk: bool,
42) -> Result<Model, &'static str> {
43    let mut has_valid_pk = false;
44    let mut properties: Vec<Property> = vec![];
45
46    for field in message.field.iter() {
47        let repeated = match field.label() {
48            field_descriptor_proto::Label::Repeated => "repeated ",
49            _ => "",
50        };
51
52        properties.push(Property {
53            _type: format!("{}{}", repeated, get_usable_type(&field, &message)),
54            name: field.name().to_string(),
55            number: field.number(),
56        });
57
58        if field.name() == "pk" {
59            if field.number() == 1 {
60                match field.r#type() {
61                    Type::String => {
62                        has_valid_pk = true;
63                    }
64                    _ => eprintln!("missing 'string pk = 1;'"),
65                }
66            }
67        }
68    }
69
70    if should_check_pk && !has_valid_pk {
71        return Err("model does not include `string pk = 1;`");
72    }
73
74    let mut nested_models: Vec<Model> = vec![];
75
76    for nested_message in message.nested_type.iter() {
77        if nested_message.name().ends_with("Entry") {
78            continue;
79        }
80
81        let nested_model = construct_model(&nested_message, false)?;
82        nested_models.push(nested_model);
83    }
84
85    Ok(Model {
86        name: message.name().to_string(),
87        properties,
88        nested_models,
89    })
90}
91
92pub fn get_complex_type(field: &FieldDescriptorProto, message: &DescriptorProto) -> String {
93    if let Some(type_name) = field.type_name().split('.').last() {
94        for nested_type in message.nested_type.iter() {
95            if nested_type.name() == type_name {
96                // check for map<,> type
97                if type_name.ends_with("Entry") {
98                    let mut key_type = "".to_string();
99                    let mut val_type = "".to_string();
100
101                    for field in nested_type.field.iter() {
102                        if field.name() == "key" {
103                            key_type = get_usable_type(&field, &message);
104                        } else if field.name() == "value" {
105                            val_type = get_usable_type(&field, &message);
106                        }
107                    }
108
109                    return format!("map<{}, {}>", key_type, val_type);
110                } else {
111                    return type_name.to_string();
112                }
113            }
114        }
115
116        return "".to_string();
117    }
118
119    "".to_string()
120}
121
122pub fn get_usable_type(field: &FieldDescriptorProto, message: &DescriptorProto) -> String {
123    match field.r#type() {
124        Type::Double => "double".to_string(),
125        Type::Float => "float".to_string(),
126        Type::Int32 => "int32".to_string(),
127        Type::Int64 => "int64".to_string(),
128        Type::Uint32 => "uint32".to_string(),
129        Type::Uint64 => "uint64".to_string(),
130        Type::Sint32 => "sint32".to_string(),
131        Type::Sint64 => "sint64".to_string(),
132        Type::Fixed32 => "fixed32".to_string(),
133        Type::Fixed64 => "fixed64".to_string(),
134        Type::Sfixed32 => "sfixed32".to_string(),
135        Type::Sfixed64 => "sfixed64".to_string(),
136        Type::Bool => "bool".to_string(),
137        Type::String => "string".to_string(),
138        Type::Bytes => "bytes".to_string(),
139        Type::Message => get_complex_type(field, message),
140
141        // !! handle explicitly
142        Type::Enum => "ENUM".to_string(),
143        Type::Group => "GROUP IS NOT SUPPORTED".to_string(),
144    }
145}