aiscript_runtime/
lib.rs

1use ast::HttpMethod;
2use axum::Json;
3use axum::response::IntoResponse;
4use axum::{response::Html, routing::*};
5use hyper::StatusCode;
6use notify::{RecommendedWatcher, RecursiveMode, Watcher};
7use sqlx::postgres::PgPoolOptions;
8use sqlx::sqlite::SqlitePoolOptions;
9use sqlx::{PgPool, SqlitePool};
10use std::path::Path;
11use std::sync::Arc;
12use std::time::Duration;
13use std::{fs, net::SocketAddr, path::PathBuf};
14use tokio::net::TcpListener;
15use tokio::sync::broadcast;
16use walkdir::WalkDir;
17
18use crate::endpoint::{Endpoint, convert_field};
19pub use config::Config;
20mod ast;
21mod config;
22mod endpoint;
23mod error;
24mod openapi;
25mod parser;
26mod utils;
27
28use aiscript_lexer as lexer;
29
/// Payload-free message broadcast to every subscriber when a watched `.ai` route file changes.
#[derive(Clone, Debug)]
struct ReloadSignal;
32
33fn read_routes() -> Vec<ast::Route> {
34    let mut routes = Vec::new();
35    for entry in WalkDir::new("routes")
36        .contents_first(true)
37        .into_iter()
38        .filter_entry(|e| {
39            e.file_type().is_file()
40                && e.file_name()
41                    .to_str()
42                    .map(|s| s.ends_with(".ai"))
43                    .unwrap_or(false)
44        })
45        .filter_map(|e| e.ok())
46    {
47        let file_path = entry.path();
48        if let Some(route) = read_single_route(file_path) {
49            routes.push(route);
50        }
51    }
52    routes
53}
54
55fn read_single_route(file_path: &Path) -> Option<ast::Route> {
56    match fs::read_to_string(file_path) {
57        Ok(input) => match parser::parse_route(&input) {
58            Ok(route) => return Some(route),
59            Err(e) => eprintln!("Error parsing route file {:?}: {}", file_path, e),
60        },
61        Err(e) => eprintln!("Error reading route file {:?}: {}", file_path, e),
62    }
63
64    None
65}
66
67pub async fn run(path: Option<PathBuf>, port: u16, reload: bool) {
68    if !reload {
69        // Run without reload functionality
70        run_server(path, port, None).await;
71        return;
72    }
73
74    // Create a channel for reload coordination
75    let (tx, _) = broadcast::channel::<ReloadSignal>(1);
76    let tx = Arc::new(tx);
77
78    // Set up file watcher
79    let watcher_tx = tx.clone();
80    let mut watcher = setup_watcher(move |event| {
81        // Only trigger reload for .ai files
82        if let Some(path) = event.paths.first().and_then(|p| p.to_str()) {
83            if path.ends_with(".ai") {
84                watcher_tx.send(ReloadSignal).unwrap();
85            }
86        }
87    })
88    .expect("Failed to setup watcher");
89
90    // Watch the routes directory
91    watcher
92        .watch(Path::new("routes"), RecursiveMode::Recursive)
93        .expect("Failed to watch routes directory");
94
95    loop {
96        let mut rx = tx.subscribe();
97        let server_handle = tokio::spawn(run_server(path.clone(), port, Some(rx.resubscribe())));
98
99        // Wait for reload signal
100        match rx.recv().await {
101            Ok(_) => {
102                println!("📑 Routes changed, reloading server...");
103                // Give some time for pending requests to complete
104                tokio::time::sleep(Duration::from_millis(100)).await;
105                server_handle.abort();
106            }
107            Err(_) => {
108                break;
109            }
110        }
111    }
112}
113
114fn setup_watcher<F>(mut callback: F) -> notify::Result<RecommendedWatcher>
115where
116    F: FnMut(notify::Event) + Send + 'static,
117{
118    let watcher = notify::recommended_watcher(move |res: notify::Result<notify::Event>| {
119        match res {
120            Ok(event) => {
121                // Only trigger on write/create/remove events
122                if event.kind.is_modify() || event.kind.is_create() || event.kind.is_remove() {
123                    callback(event);
124                }
125            }
126            Err(e) => println!("Watch error: {:?}", e),
127        }
128    })?;
129    Ok(watcher)
130}
131
132pub async fn get_pg_connection() -> Option<PgPool> {
133    let config = Config::get();
134    match config.database.get_postgres_url() {
135        Some(url) => PgPoolOptions::new()
136            .max_connections(5)
137            .connect(&url)
138            .await
139            .ok(),
140        None => None,
141    }
142}
143
144pub async fn get_sqlite_connection() -> Option<SqlitePool> {
145    let config = Config::get();
146    match config.database.get_sqlite_url() {
147        Some(url) => SqlitePoolOptions::new()
148            .max_connections(5)
149            .connect(&url)
150            .await
151            .ok(),
152        None => None,
153    }
154}
155
156pub async fn get_redis_connection() -> Option<redis::aio::MultiplexedConnection> {
157    let config = Config::get();
158    match config.database.get_redis_url() {
159        Some(url) => {
160            let client = redis::Client::open(url).unwrap();
161            let conn = client.get_multiplexed_async_connection().await.unwrap();
162            Some(conn)
163        }
164        None => None,
165    }
166}
167
168async fn run_server(
169    path: Option<PathBuf>,
170    port: u16,
171    reload_rx: Option<broadcast::Receiver<ReloadSignal>>,
172) {
173    let config = Config::get();
174
175    let routes = if let Some(file_path) = path {
176        read_single_route(&file_path).into_iter().collect()
177    } else {
178        read_routes()
179    };
180
181    if routes.is_empty() {
182        eprintln!("Warning: No valid routes found!");
183        return;
184    }
185
186    let mut router = Router::new();
187    let openapi = openapi::OpenAPIGenerator::generate(&routes);
188    router = router.route("/openapi.json", get(move || async { Json(openapi) }));
189
190    if config.apidoc.enabled {
191        match config.apidoc.doc_type {
192            config::ApiDocType::Swagger => {
193                // router = router.route(
194                //     &config.apidoc.path,
195                //     get(move || async { Html(include_str!("openapi/swagger.html")) }),
196                // );
197            }
198            config::ApiDocType::Redoc => {
199                router = router.route(
200                    &config.apidoc.path,
201                    get(|| async { Html(include_str!("openapi/redoc.html")) }),
202                );
203            }
204        }
205    }
206
207    let pg_connection = get_pg_connection().await;
208    let sqlite_connection = get_sqlite_connection().await;
209    let redis_connection = get_redis_connection().await;
210    for route in routes {
211        let mut r = Router::new();
212        for endpoint_spec in route.endpoints {
213            let endpoint = Endpoint {
214                annotation: endpoint_spec.annotation.or(&route.annotation),
215                path_params: endpoint_spec.path.into_iter().map(convert_field).collect(),
216                query_params: endpoint_spec.query.into_iter().map(convert_field).collect(),
217                body_type: endpoint_spec.body.kind,
218                body_fields: endpoint_spec
219                    .body
220                    .fields
221                    .into_iter()
222                    .map(convert_field)
223                    .collect(),
224                script: endpoint_spec.statements,
225                path_specs: endpoint_spec.path_specs,
226                pg_connection: pg_connection.as_ref().cloned(),
227                sqlite_connection: sqlite_connection.as_ref().cloned(),
228                redis_connection: redis_connection.as_ref().cloned(),
229            };
230
231            for path_spec in &endpoint.path_specs[..endpoint.path_specs.len() - 1] {
232                let service_fn = match path_spec.method {
233                    HttpMethod::Get => get_service,
234                    HttpMethod::Post => post_service,
235                    HttpMethod::Put => put_service,
236                    HttpMethod::Delete => delete_service,
237                };
238                r = r.route(&path_spec.path, service_fn(endpoint.clone()));
239            }
240
241            // avoid clone the last one
242            let last_path_specs = &endpoint.path_specs[endpoint.path_specs.len() - 1];
243            let service_fn = match last_path_specs.method {
244                HttpMethod::Get => get_service,
245                HttpMethod::Post => post_service,
246                HttpMethod::Put => put_service,
247                HttpMethod::Delete => delete_service,
248            };
249            r = r.route(&last_path_specs.path.clone(), service_fn(endpoint));
250        }
251
252        if route.prefix == "/" {
253            // axum don't allow use nest() with root path
254            router = router.merge(r);
255        } else {
256            router = router.nest(&route.prefix, r);
257        }
258    }
259
260    // Add a custom 404 handler for unmatched routes
261    async fn handle_404() -> impl IntoResponse {
262        let error_json = serde_json::json!({
263            "message": "Not Found"
264        });
265
266        (StatusCode::NOT_FOUND, Json(error_json))
267    }
268
269    // Add the fallback handler to the router
270    router = router.fallback(handle_404);
271
272    let addr = SocketAddr::from(([0, 0, 0, 0], port));
273    let listener = TcpListener::bind(addr).await.unwrap();
274
275    match reload_rx {
276        Some(mut rx) => {
277            // Create a shutdown signal for reload case
278            let (close_tx, close_rx) = tokio::sync::oneshot::channel();
279
280            // Handle reload messages
281            let shutdown_task = tokio::spawn(async move {
282                if rx.recv().await.is_ok() {
283                    close_tx.send(()).unwrap();
284                }
285            });
286
287            // Run the server with graceful shutdown
288            axum::serve(listener, router)
289                .with_graceful_shutdown(async {
290                    let _ = close_rx.await;
291                })
292                .await
293                .unwrap();
294
295            shutdown_task.abort();
296        }
297        None => {
298            // Run without reload capability
299            axum::serve(listener, router).await.unwrap();
300        }
301    }
302}