1use ast::HttpMethod;
2use axum::Json;
3use axum::response::IntoResponse;
4use axum::{response::Html, routing::*};
5use hyper::StatusCode;
6use notify::{RecommendedWatcher, RecursiveMode, Watcher};
7use sqlx::postgres::PgPoolOptions;
8use sqlx::sqlite::SqlitePoolOptions;
9use sqlx::{PgPool, SqlitePool};
10use std::path::Path;
11use std::sync::Arc;
12use std::time::Duration;
13use std::{fs, net::SocketAddr, path::PathBuf};
14use tokio::net::TcpListener;
15use tokio::sync::broadcast;
16use walkdir::WalkDir;
17
18use crate::endpoint::{Endpoint, convert_field};
19pub use config::Config;
20mod ast;
21mod config;
22mod endpoint;
23mod error;
24mod openapi;
25mod parser;
26mod utils;
27
28use aiscript_lexer as lexer;
29
/// Unit message broadcast whenever route files change, telling the running
/// server instance to shut down so a fresh one can be started.
#[derive(Clone, Debug)]
struct ReloadSignal;
32
33fn read_routes() -> Vec<ast::Route> {
34 let mut routes = Vec::new();
35 for entry in WalkDir::new("routes")
36 .contents_first(true)
37 .into_iter()
38 .filter_entry(|e| {
39 e.file_type().is_file()
40 && e.file_name()
41 .to_str()
42 .map(|s| s.ends_with(".ai"))
43 .unwrap_or(false)
44 })
45 .filter_map(|e| e.ok())
46 {
47 let file_path = entry.path();
48 if let Some(route) = read_single_route(file_path) {
49 routes.push(route);
50 }
51 }
52 routes
53}
54
55fn read_single_route(file_path: &Path) -> Option<ast::Route> {
56 match fs::read_to_string(file_path) {
57 Ok(input) => match parser::parse_route(&input) {
58 Ok(route) => return Some(route),
59 Err(e) => eprintln!("Error parsing route file {:?}: {}", file_path, e),
60 },
61 Err(e) => eprintln!("Error reading route file {:?}: {}", file_path, e),
62 }
63
64 None
65}
66
67pub async fn run(path: Option<PathBuf>, port: u16, reload: bool) {
68 if !reload {
69 run_server(path, port, None).await;
71 return;
72 }
73
74 let (tx, _) = broadcast::channel::<ReloadSignal>(1);
76 let tx = Arc::new(tx);
77
78 let watcher_tx = tx.clone();
80 let mut watcher = setup_watcher(move |event| {
81 if let Some(path) = event.paths.first().and_then(|p| p.to_str()) {
83 if path.ends_with(".ai") {
84 watcher_tx.send(ReloadSignal).unwrap();
85 }
86 }
87 })
88 .expect("Failed to setup watcher");
89
90 watcher
92 .watch(Path::new("routes"), RecursiveMode::Recursive)
93 .expect("Failed to watch routes directory");
94
95 loop {
96 let mut rx = tx.subscribe();
97 let server_handle = tokio::spawn(run_server(path.clone(), port, Some(rx.resubscribe())));
98
99 match rx.recv().await {
101 Ok(_) => {
102 println!("📑 Routes changed, reloading server...");
103 tokio::time::sleep(Duration::from_millis(100)).await;
105 server_handle.abort();
106 }
107 Err(_) => {
108 break;
109 }
110 }
111 }
112}
113
114fn setup_watcher<F>(mut callback: F) -> notify::Result<RecommendedWatcher>
115where
116 F: FnMut(notify::Event) + Send + 'static,
117{
118 let watcher = notify::recommended_watcher(move |res: notify::Result<notify::Event>| {
119 match res {
120 Ok(event) => {
121 if event.kind.is_modify() || event.kind.is_create() || event.kind.is_remove() {
123 callback(event);
124 }
125 }
126 Err(e) => println!("Watch error: {:?}", e),
127 }
128 })?;
129 Ok(watcher)
130}
131
132pub async fn get_pg_connection() -> Option<PgPool> {
133 let config = Config::get();
134 match config.database.get_postgres_url() {
135 Some(url) => PgPoolOptions::new()
136 .max_connections(5)
137 .connect(&url)
138 .await
139 .ok(),
140 None => None,
141 }
142}
143
144pub async fn get_sqlite_connection() -> Option<SqlitePool> {
145 let config = Config::get();
146 match config.database.get_sqlite_url() {
147 Some(url) => SqlitePoolOptions::new()
148 .max_connections(5)
149 .connect(&url)
150 .await
151 .ok(),
152 None => None,
153 }
154}
155
156pub async fn get_redis_connection() -> Option<redis::aio::MultiplexedConnection> {
157 let config = Config::get();
158 match config.database.get_redis_url() {
159 Some(url) => {
160 let client = redis::Client::open(url).unwrap();
161 let conn = client.get_multiplexed_async_connection().await.unwrap();
162 Some(conn)
163 }
164 None => None,
165 }
166}
167
168async fn run_server(
169 path: Option<PathBuf>,
170 port: u16,
171 reload_rx: Option<broadcast::Receiver<ReloadSignal>>,
172) {
173 let config = Config::get();
174
175 let routes = if let Some(file_path) = path {
176 read_single_route(&file_path).into_iter().collect()
177 } else {
178 read_routes()
179 };
180
181 if routes.is_empty() {
182 eprintln!("Warning: No valid routes found!");
183 return;
184 }
185
186 let mut router = Router::new();
187 let openapi = openapi::OpenAPIGenerator::generate(&routes);
188 router = router.route("/openapi.json", get(move || async { Json(openapi) }));
189
190 if config.apidoc.enabled {
191 match config.apidoc.doc_type {
192 config::ApiDocType::Swagger => {
193 }
198 config::ApiDocType::Redoc => {
199 router = router.route(
200 &config.apidoc.path,
201 get(|| async { Html(include_str!("openapi/redoc.html")) }),
202 );
203 }
204 }
205 }
206
207 let pg_connection = get_pg_connection().await;
208 let sqlite_connection = get_sqlite_connection().await;
209 let redis_connection = get_redis_connection().await;
210 for route in routes {
211 let mut r = Router::new();
212 for endpoint_spec in route.endpoints {
213 let endpoint = Endpoint {
214 annotation: endpoint_spec.annotation.or(&route.annotation),
215 path_params: endpoint_spec.path.into_iter().map(convert_field).collect(),
216 query_params: endpoint_spec.query.into_iter().map(convert_field).collect(),
217 body_type: endpoint_spec.body.kind,
218 body_fields: endpoint_spec
219 .body
220 .fields
221 .into_iter()
222 .map(convert_field)
223 .collect(),
224 script: endpoint_spec.statements,
225 path_specs: endpoint_spec.path_specs,
226 pg_connection: pg_connection.as_ref().cloned(),
227 sqlite_connection: sqlite_connection.as_ref().cloned(),
228 redis_connection: redis_connection.as_ref().cloned(),
229 };
230
231 for path_spec in &endpoint.path_specs[..endpoint.path_specs.len() - 1] {
232 let service_fn = match path_spec.method {
233 HttpMethod::Get => get_service,
234 HttpMethod::Post => post_service,
235 HttpMethod::Put => put_service,
236 HttpMethod::Delete => delete_service,
237 };
238 r = r.route(&path_spec.path, service_fn(endpoint.clone()));
239 }
240
241 let last_path_specs = &endpoint.path_specs[endpoint.path_specs.len() - 1];
243 let service_fn = match last_path_specs.method {
244 HttpMethod::Get => get_service,
245 HttpMethod::Post => post_service,
246 HttpMethod::Put => put_service,
247 HttpMethod::Delete => delete_service,
248 };
249 r = r.route(&last_path_specs.path.clone(), service_fn(endpoint));
250 }
251
252 if route.prefix == "/" {
253 router = router.merge(r);
255 } else {
256 router = router.nest(&route.prefix, r);
257 }
258 }
259
260 async fn handle_404() -> impl IntoResponse {
262 let error_json = serde_json::json!({
263 "message": "Not Found"
264 });
265
266 (StatusCode::NOT_FOUND, Json(error_json))
267 }
268
269 router = router.fallback(handle_404);
271
272 let addr = SocketAddr::from(([0, 0, 0, 0], port));
273 let listener = TcpListener::bind(addr).await.unwrap();
274
275 match reload_rx {
276 Some(mut rx) => {
277 let (close_tx, close_rx) = tokio::sync::oneshot::channel();
279
280 let shutdown_task = tokio::spawn(async move {
282 if rx.recv().await.is_ok() {
283 close_tx.send(()).unwrap();
284 }
285 });
286
287 axum::serve(listener, router)
289 .with_graceful_shutdown(async {
290 let _ = close_rx.await;
291 })
292 .await
293 .unwrap();
294
295 shutdown_task.abort();
296 }
297 None => {
298 axum::serve(listener, router).await.unwrap();
300 }
301 }
302}