1use ast::HttpMethod;
2use axum::Json;
3use axum::{response::Html, routing::*};
4use notify::{RecommendedWatcher, RecursiveMode, Watcher};
5use sqlx::postgres::PgPoolOptions;
6use sqlx::sqlite::SqlitePoolOptions;
7use sqlx::{PgPool, SqlitePool};
8use std::path::Path;
9use std::sync::Arc;
10use std::time::Duration;
11use std::{fs, net::SocketAddr, path::PathBuf};
12use tokio::net::TcpListener;
13use tokio::sync::broadcast;
14use walkdir::WalkDir;
15
16use crate::endpoint::{Endpoint, convert_field};
17pub use config::Config;
18mod ast;
19mod config;
20mod endpoint;
21mod error;
22mod openapi;
23mod parser;
24mod utils;
25
26use aiscript_lexer as lexer;
27
/// Unit message broadcast when a watched `.ai` route file changes; receiving it
/// tells the reload loop and the running server that a restart is due.
#[derive(Clone, Debug)]
struct ReloadSignal;
30
31fn read_routes() -> Vec<ast::Route> {
32 let mut routes = Vec::new();
33 for entry in WalkDir::new("routes")
34 .contents_first(true)
35 .into_iter()
36 .filter_entry(|e| {
37 e.file_type().is_file()
38 && e.file_name()
39 .to_str()
40 .map(|s| s.ends_with(".ai"))
41 .unwrap_or(false)
42 })
43 .filter_map(|e| e.ok())
44 {
45 let file_path = entry.path();
46 if let Some(route) = read_single_route(file_path) {
47 routes.push(route);
48 }
49 }
50 routes
51}
52
53fn read_single_route(file_path: &Path) -> Option<ast::Route> {
54 match fs::read_to_string(file_path) {
55 Ok(input) => match parser::parse_route(&input) {
56 Ok(route) => return Some(route),
57 Err(e) => eprintln!("Error parsing route file {:?}: {}", file_path, e),
58 },
59 Err(e) => eprintln!("Error reading route file {:?}: {}", file_path, e),
60 }
61
62 None
63}
64
65pub async fn run(path: Option<PathBuf>, port: u16, reload: bool) {
66 if !reload {
67 run_server(path, port, None).await;
69 return;
70 }
71
72 let (tx, _) = broadcast::channel::<ReloadSignal>(1);
74 let tx = Arc::new(tx);
75
76 let watcher_tx = tx.clone();
78 let mut watcher = setup_watcher(move |event| {
79 if let Some(path) = event.paths.first().and_then(|p| p.to_str()) {
81 if path.ends_with(".ai") {
82 watcher_tx.send(ReloadSignal).unwrap();
83 }
84 }
85 })
86 .expect("Failed to setup watcher");
87
88 watcher
90 .watch(Path::new("routes"), RecursiveMode::Recursive)
91 .expect("Failed to watch routes directory");
92
93 loop {
94 let mut rx = tx.subscribe();
95 let server_handle = tokio::spawn(run_server(path.clone(), port, Some(rx.resubscribe())));
96
97 match rx.recv().await {
99 Ok(_) => {
100 println!("📑 Routes changed, reloading server...");
101 tokio::time::sleep(Duration::from_millis(100)).await;
103 server_handle.abort();
104 }
105 Err(_) => {
106 break;
107 }
108 }
109 }
110}
111
112fn setup_watcher<F>(mut callback: F) -> notify::Result<RecommendedWatcher>
113where
114 F: FnMut(notify::Event) + Send + 'static,
115{
116 let watcher = notify::recommended_watcher(move |res: notify::Result<notify::Event>| {
117 match res {
118 Ok(event) => {
119 if event.kind.is_modify() || event.kind.is_create() || event.kind.is_remove() {
121 callback(event);
122 }
123 }
124 Err(e) => println!("Watch error: {:?}", e),
125 }
126 })?;
127 Ok(watcher)
128}
129
130pub async fn get_pg_connection() -> Option<PgPool> {
131 let config = Config::get();
132 match config.database.get_postgres_url() {
133 Some(url) => PgPoolOptions::new()
134 .max_connections(5)
135 .connect(&url)
136 .await
137 .ok(),
138 None => None,
139 }
140}
141
142pub async fn get_sqlite_connection() -> Option<SqlitePool> {
143 let config = Config::get();
144 match config.database.get_sqlite_url() {
145 Some(url) => SqlitePoolOptions::new()
146 .max_connections(5)
147 .connect(&url)
148 .await
149 .ok(),
150 None => None,
151 }
152}
153
154pub async fn get_redis_connection() -> Option<redis::aio::MultiplexedConnection> {
155 let config = Config::get();
156 match config.database.get_redis_url() {
157 Some(url) => {
158 let client = redis::Client::open(url).unwrap();
159 let conn = client.get_multiplexed_async_connection().await.unwrap();
160 Some(conn)
161 }
162 None => None,
163 }
164}
165
166async fn run_server(
167 path: Option<PathBuf>,
168 port: u16,
169 reload_rx: Option<broadcast::Receiver<ReloadSignal>>,
170) {
171 let config = Config::get();
172
173 let routes = if let Some(file_path) = path {
174 read_single_route(&file_path).into_iter().collect()
175 } else {
176 read_routes()
177 };
178
179 if routes.is_empty() {
180 eprintln!("Warning: No valid routes found!");
181 return;
182 }
183
184 let mut router = Router::new();
185 let openapi = openapi::OpenAPIGenerator::generate(&routes);
186 router = router.route("/openapi.json", get(move || async { Json(openapi) }));
187
188 if config.apidoc.enabled {
189 match config.apidoc.doc_type {
190 config::ApiDocType::Swagger => {
191 }
196 config::ApiDocType::Redoc => {
197 router = router.route(
198 &config.apidoc.path,
199 get(|| async { Html(include_str!("openapi/redoc.html")) }),
200 );
201 }
202 }
203 }
204
205 let pg_connection = get_pg_connection().await;
206 let sqlite_connection = get_sqlite_connection().await;
207 let redis_connection = get_redis_connection().await;
208 for route in routes {
209 let mut r = Router::new();
210 for endpoint_spec in route.endpoints {
211 let endpoint = Endpoint {
212 annotation: endpoint_spec.annotation.or(route.annotation),
213 query_params: endpoint_spec.query.into_iter().map(convert_field).collect(),
214 body_type: endpoint_spec.body.kind,
215 body_fields: endpoint_spec
216 .body
217 .fields
218 .into_iter()
219 .map(convert_field)
220 .collect(),
221 script: endpoint_spec.statements,
222 path_specs: endpoint_spec.path_specs,
223 pg_connection: pg_connection.as_ref().cloned(),
224 sqlite_connection: sqlite_connection.as_ref().cloned(),
225 redis_connection: redis_connection.as_ref().cloned(),
226 };
227
228 for path_spec in &endpoint.path_specs[..endpoint.path_specs.len() - 1] {
229 let service_fn = match path_spec.method {
230 HttpMethod::Get => get_service,
231 HttpMethod::Post => post_service,
232 HttpMethod::Put => put_service,
233 HttpMethod::Delete => delete_service,
234 };
235 r = r.route(&path_spec.path, service_fn(endpoint.clone()));
236 }
237
238 let last_path_specs = &endpoint.path_specs[endpoint.path_specs.len() - 1];
240 let service_fn = match last_path_specs.method {
241 HttpMethod::Get => get_service,
242 HttpMethod::Post => post_service,
243 HttpMethod::Put => put_service,
244 HttpMethod::Delete => delete_service,
245 };
246 r = r.route(&last_path_specs.path.clone(), service_fn(endpoint));
247 }
248
249 if route.prefix == "/" {
250 router = router.merge(r);
252 } else {
253 router = router.nest(&route.prefix, r);
254 }
255 }
256
257 let addr = SocketAddr::from(([0, 0, 0, 0], port));
258 let listener = TcpListener::bind(addr).await.unwrap();
259
260 match reload_rx {
261 Some(mut rx) => {
262 let (close_tx, close_rx) = tokio::sync::oneshot::channel();
264
265 let shutdown_task = tokio::spawn(async move {
267 if rx.recv().await.is_ok() {
268 close_tx.send(()).unwrap();
269 }
270 });
271
272 axum::serve(listener, router)
274 .with_graceful_shutdown(async {
275 let _ = close_rx.await;
276 })
277 .await
278 .unwrap();
279
280 shutdown_task.abort();
281 }
282 None => {
283 axum::serve(listener, router).await.unwrap();
285 }
286 }
287}