1use crate::error::{QueryError, Result};
2use std::any::TypeId;
3use std::collections::HashMap;
4
/// Column types that the schema generator can emit.
///
/// Each variant maps to exactly one SQL type keyword via [`SqlType::to_sql`].
/// The emitted names (`JSONB`, `BYTEA`, `DOUBLE PRECISION`) appear to target
/// PostgreSQL — NOTE(review): confirm if other backends are expected.
///
/// The enum is fieldless, so it is `Copy` and `Hash`: it can be passed by value
/// freely and used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlType {
    BigInt,
    Integer,
    SmallInt,
    Real,
    DoublePrecision,
    Text,
    Boolean,
    Timestamp,
    Json,
    Bytea,
}

impl SqlType {
    /// Returns the SQL type name used in `CREATE TABLE` column definitions.
    ///
    /// Note that [`SqlType::Json`] is rendered as `JSONB`, not `JSON`.
    pub fn to_sql(&self) -> &'static str {
        match self {
            SqlType::BigInt => "BIGINT",
            SqlType::Integer => "INTEGER",
            SqlType::SmallInt => "SMALLINT",
            SqlType::Real => "REAL",
            SqlType::DoublePrecision => "DOUBLE PRECISION",
            SqlType::Text => "TEXT",
            SqlType::Boolean => "BOOLEAN",
            SqlType::Timestamp => "TIMESTAMP",
            SqlType::Json => "JSONB",
            SqlType::Bytea => "BYTEA",
        }
    }
}
36
37#[derive(Debug, Clone)]
39pub struct ColumnDef {
40 pub name: String,
41 pub sql_type: SqlType,
42 pub nullable: bool,
43 pub default: Option<String>,
44}
45
46#[derive(Debug, Clone)]
48pub struct TableSchema {
49 pub name: String,
50 pub columns: Vec<ColumnDef>,
51 pub indexes: Vec<IndexDef>,
52}
53
/// Definition of a secondary index on a component table.
#[derive(Clone, Debug)]
pub struct IndexDef {
    /// Index name used in `CREATE INDEX`.
    pub name: String,
    /// Indexed column names, rendered in this order and comma-separated.
    pub columns: Vec<String>,
    /// When `true`, the index is rendered as `CREATE UNIQUE INDEX`.
    pub unique: bool,
}
61
62pub struct ComponentRegistration {
64 pub type_id: TypeId,
65 pub name: String,
66 pub schema: TableSchema,
67}
68
69pub struct SchemaGenerator {
71 registrations: HashMap<TypeId, ComponentRegistration>,
72 component_names: HashMap<String, TypeId>,
73}
74
75impl SchemaGenerator {
76 pub fn new() -> Self {
77 Self {
78 registrations: HashMap::new(),
79 component_names: HashMap::new(),
80 }
81 }
82
83 pub fn register<T>(&mut self, name: &str, fields: Vec<(&str, SqlType, bool)>) -> Result<()>
85 where
86 T: 'static,
87 {
88 let type_id = TypeId::of::<T>();
89
90 if self.registrations.contains_key(&type_id) {
91 return Err(QueryError::Schema(format!(
92 "Component {} already registered",
93 name
94 )));
95 }
96
97 let mut columns = vec![
98 ColumnDef {
99 name: "entity_id".to_string(),
100 sql_type: SqlType::BigInt,
101 nullable: false,
102 default: None,
103 },
104 ];
105
106 for (field_name, field_type, nullable) in fields {
107 columns.push(ColumnDef {
108 name: field_name.to_string(),
109 sql_type: field_type,
110 nullable,
111 default: None,
112 });
113 }
114
115 columns.push(ColumnDef {
117 name: "_tx2_updated_at".to_string(),
118 sql_type: SqlType::Timestamp,
119 nullable: false,
120 default: Some("CURRENT_TIMESTAMP".to_string()),
121 });
122
123 let schema = TableSchema {
124 name: name.to_string(),
125 columns,
126 indexes: vec![],
127 };
128
129 let registration = ComponentRegistration {
130 type_id,
131 name: name.to_string(),
132 schema,
133 };
134
135 self.component_names.insert(name.to_string(), type_id);
136 self.registrations.insert(type_id, registration);
137
138 Ok(())
139 }
140
141 pub fn get_schema(&self, type_id: &TypeId) -> Option<&TableSchema> {
143 self.registrations.get(type_id).map(|r| &r.schema)
144 }
145
146 pub fn get_schema_by_name(&self, name: &str) -> Option<&TableSchema> {
148 self.component_names
149 .get(name)
150 .and_then(|type_id| self.get_schema(type_id))
151 }
152
153 pub fn generate_ddl(&self) -> String {
155 let mut ddl = String::new();
156
157 for registration in self.registrations.values() {
158 ddl.push_str(&self.generate_table_ddl(®istration.schema));
159 ddl.push_str("\n\n");
160 }
161
162 ddl.trim().to_string()
163 }
164
165 pub fn generate_table_ddl(&self, schema: &TableSchema) -> String {
167 let mut sql = format!("CREATE TABLE IF NOT EXISTS {} (\n", schema.name);
168
169 let column_defs: Vec<String> = schema
170 .columns
171 .iter()
172 .map(|col| {
173 let mut def = format!(" {} {}", col.name, col.sql_type.to_sql());
174
175 if !col.nullable {
176 def.push_str(" NOT NULL");
177 }
178
179 if let Some(default) = &col.default {
180 def.push_str(&format!(" DEFAULT {}", default));
181 }
182
183 def
184 })
185 .collect();
186
187 sql.push_str(&column_defs.join(",\n"));
188 sql.push_str(",\n PRIMARY KEY (entity_id)\n");
189 sql.push_str(");");
190
191 for index in &schema.indexes {
193 sql.push_str("\n\n");
194 sql.push_str(&self.generate_index_ddl(&schema.name, index));
195 }
196
197 sql
198 }
199
200 pub fn generate_index_ddl(&self, table_name: &str, index: &IndexDef) -> String {
202 let unique = if index.unique { "UNIQUE " } else { "" };
203 format!(
204 "CREATE {}INDEX IF NOT EXISTS {} ON {} ({});",
205 unique,
206 index.name,
207 table_name,
208 index.columns.join(", ")
209 )
210 }
211
212 pub fn add_index(
214 &mut self,
215 type_id: &TypeId,
216 index_name: &str,
217 columns: Vec<String>,
218 unique: bool,
219 ) -> Result<()> {
220 let registration = self
221 .registrations
222 .get_mut(type_id)
223 .ok_or_else(|| QueryError::ComponentNotRegistered(format!("{:?}", type_id)))?;
224
225 registration.schema.indexes.push(IndexDef {
226 name: index_name.to_string(),
227 columns,
228 unique,
229 });
230
231 Ok(())
232 }
233
234 pub fn list_components(&self) -> Vec<String> {
236 self.component_names.keys().cloned().collect()
237 }
238}
239
240impl Default for SchemaGenerator {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero-sized marker whose `TypeId` keys the test registrations.
    #[derive(Debug)]
    struct TestComponent;

    /// Registers a three-field component and checks that every expected column
    /// definition and the primary key appear in the generated DDL.
    #[test]
    fn test_schema_generation() {
        let mut schema_gen = SchemaGenerator::new();
        let player_fields = vec![
            ("name", SqlType::Text, false),
            ("email", SqlType::Text, false),
            ("score", SqlType::Integer, false),
        ];
        schema_gen
            .register::<TestComponent>("Player", player_fields)
            .unwrap();

        let ddl = schema_gen.generate_ddl();
        let expected_fragments = [
            "CREATE TABLE IF NOT EXISTS Player",
            "entity_id BIGINT NOT NULL",
            "name TEXT NOT NULL",
            "email TEXT NOT NULL",
            "score INTEGER NOT NULL",
            "_tx2_updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP",
            "PRIMARY KEY (entity_id)",
        ];
        for fragment in expected_fragments {
            assert!(ddl.contains(fragment));
        }
    }

    /// Adds a non-unique index to a registered component and checks its DDL.
    #[test]
    fn test_index_generation() {
        let type_id = TypeId::of::<TestComponent>();
        let mut schema_gen = SchemaGenerator::new();
        schema_gen
            .register::<TestComponent>("Player", vec![("name", SqlType::Text, false)])
            .unwrap();
        schema_gen
            .add_index(&type_id, "idx_player_name", vec![String::from("name")], false)
            .unwrap();

        let schema = schema_gen.get_schema(&type_id).unwrap();
        let first_index = schema.indexes.first().unwrap();
        let index_ddl = schema_gen.generate_index_ddl(&schema.name, first_index);

        assert!(index_ddl.contains("CREATE INDEX IF NOT EXISTS idx_player_name"));
        assert!(index_ddl.contains("ON Player (name)"));
    }
}