aiscript_runtime/
lib.rs

1use ast::HttpMethod;
2use axum::Json;
3use axum::{response::Html, routing::*};
4use notify::{RecommendedWatcher, RecursiveMode, Watcher};
5use sqlx::postgres::PgPoolOptions;
6use sqlx::sqlite::SqlitePoolOptions;
7use sqlx::{PgPool, SqlitePool};
8use std::path::Path;
9use std::sync::Arc;
10use std::time::Duration;
11use std::{fs, net::SocketAddr, path::PathBuf};
12use tokio::net::TcpListener;
13use tokio::sync::broadcast;
14use walkdir::WalkDir;
15
16use crate::endpoint::{Endpoint, convert_field};
17pub use config::Config;
18mod ast;
19mod config;
20mod endpoint;
21mod error;
22mod openapi;
23mod parser;
24mod utils;
25
26use aiscript_lexer as lexer;
27
/// Payload-free marker message pushed through the reload broadcast channel
/// whenever a watched `.ai` route file changes.
#[derive(Clone, Debug)]
struct ReloadSignal;
30
31fn read_routes() -> Vec<ast::Route> {
32    let mut routes = Vec::new();
33    for entry in WalkDir::new("routes")
34        .contents_first(true)
35        .into_iter()
36        .filter_entry(|e| {
37            e.file_type().is_file()
38                && e.file_name()
39                    .to_str()
40                    .map(|s| s.ends_with(".ai"))
41                    .unwrap_or(false)
42        })
43        .filter_map(|e| e.ok())
44    {
45        let file_path = entry.path();
46        if let Some(route) = read_single_route(file_path) {
47            routes.push(route);
48        }
49    }
50    routes
51}
52
53fn read_single_route(file_path: &Path) -> Option<ast::Route> {
54    match fs::read_to_string(file_path) {
55        Ok(input) => match parser::parse_route(&input) {
56            Ok(route) => return Some(route),
57            Err(e) => eprintln!("Error parsing route file {:?}: {}", file_path, e),
58        },
59        Err(e) => eprintln!("Error reading route file {:?}: {}", file_path, e),
60    }
61
62    None
63}
64
65pub async fn run(path: Option<PathBuf>, port: u16, reload: bool) {
66    if !reload {
67        // Run without reload functionality
68        run_server(path, port, None).await;
69        return;
70    }
71
72    // Create a channel for reload coordination
73    let (tx, _) = broadcast::channel::<ReloadSignal>(1);
74    let tx = Arc::new(tx);
75
76    // Set up file watcher
77    let watcher_tx = tx.clone();
78    let mut watcher = setup_watcher(move |event| {
79        // Only trigger reload for .ai files
80        if let Some(path) = event.paths.first().and_then(|p| p.to_str()) {
81            if path.ends_with(".ai") {
82                watcher_tx.send(ReloadSignal).unwrap();
83            }
84        }
85    })
86    .expect("Failed to setup watcher");
87
88    // Watch the routes directory
89    watcher
90        .watch(Path::new("routes"), RecursiveMode::Recursive)
91        .expect("Failed to watch routes directory");
92
93    loop {
94        let mut rx = tx.subscribe();
95        let server_handle = tokio::spawn(run_server(path.clone(), port, Some(rx.resubscribe())));
96
97        // Wait for reload signal
98        match rx.recv().await {
99            Ok(_) => {
100                println!("📑 Routes changed, reloading server...");
101                // Give some time for pending requests to complete
102                tokio::time::sleep(Duration::from_millis(100)).await;
103                server_handle.abort();
104            }
105            Err(_) => {
106                break;
107            }
108        }
109    }
110}
111
112fn setup_watcher<F>(mut callback: F) -> notify::Result<RecommendedWatcher>
113where
114    F: FnMut(notify::Event) + Send + 'static,
115{
116    let watcher = notify::recommended_watcher(move |res: notify::Result<notify::Event>| {
117        match res {
118            Ok(event) => {
119                // Only trigger on write/create/remove events
120                if event.kind.is_modify() || event.kind.is_create() || event.kind.is_remove() {
121                    callback(event);
122                }
123            }
124            Err(e) => println!("Watch error: {:?}", e),
125        }
126    })?;
127    Ok(watcher)
128}
129
130pub async fn get_pg_connection() -> Option<PgPool> {
131    let config = Config::get();
132    match config.database.get_postgres_url() {
133        Some(url) => PgPoolOptions::new()
134            .max_connections(5)
135            .connect(&url)
136            .await
137            .ok(),
138        None => None,
139    }
140}
141
142pub async fn get_sqlite_connection() -> Option<SqlitePool> {
143    let config = Config::get();
144    match config.database.get_sqlite_url() {
145        Some(url) => SqlitePoolOptions::new()
146            .max_connections(5)
147            .connect(&url)
148            .await
149            .ok(),
150        None => None,
151    }
152}
153
154pub async fn get_redis_connection() -> Option<redis::aio::MultiplexedConnection> {
155    let config = Config::get();
156    match config.database.get_redis_url() {
157        Some(url) => {
158            let client = redis::Client::open(url).unwrap();
159            let conn = client.get_multiplexed_async_connection().await.unwrap();
160            Some(conn)
161        }
162        None => None,
163    }
164}
165
166async fn run_server(
167    path: Option<PathBuf>,
168    port: u16,
169    reload_rx: Option<broadcast::Receiver<ReloadSignal>>,
170) {
171    let config = Config::get();
172
173    let routes = if let Some(file_path) = path {
174        read_single_route(&file_path).into_iter().collect()
175    } else {
176        read_routes()
177    };
178
179    if routes.is_empty() {
180        eprintln!("Warning: No valid routes found!");
181        return;
182    }
183
184    let mut router = Router::new();
185    let openapi = openapi::OpenAPIGenerator::generate(&routes);
186    router = router.route("/openapi.json", get(move || async { Json(openapi) }));
187
188    if config.apidoc.enabled {
189        match config.apidoc.doc_type {
190            config::ApiDocType::Swagger => {
191                // router = router.route(
192                //     &config.apidoc.path,
193                //     get(move || async { Html(include_str!("openapi/swagger.html")) }),
194                // );
195            }
196            config::ApiDocType::Redoc => {
197                router = router.route(
198                    &config.apidoc.path,
199                    get(|| async { Html(include_str!("openapi/redoc.html")) }),
200                );
201            }
202        }
203    }
204
205    let pg_connection = get_pg_connection().await;
206    let sqlite_connection = get_sqlite_connection().await;
207    let redis_connection = get_redis_connection().await;
208    for route in routes {
209        let mut r = Router::new();
210        for endpoint_spec in route.endpoints {
211            let endpoint = Endpoint {
212                annotation: endpoint_spec.annotation.or(route.annotation),
213                query_params: endpoint_spec.query.into_iter().map(convert_field).collect(),
214                body_type: endpoint_spec.body.kind,
215                body_fields: endpoint_spec
216                    .body
217                    .fields
218                    .into_iter()
219                    .map(convert_field)
220                    .collect(),
221                script: endpoint_spec.statements,
222                path_specs: endpoint_spec.path_specs,
223                pg_connection: pg_connection.as_ref().cloned(),
224                sqlite_connection: sqlite_connection.as_ref().cloned(),
225                redis_connection: redis_connection.as_ref().cloned(),
226            };
227
228            for path_spec in &endpoint.path_specs[..endpoint.path_specs.len() - 1] {
229                let service_fn = match path_spec.method {
230                    HttpMethod::Get => get_service,
231                    HttpMethod::Post => post_service,
232                    HttpMethod::Put => put_service,
233                    HttpMethod::Delete => delete_service,
234                };
235                r = r.route(&path_spec.path, service_fn(endpoint.clone()));
236            }
237
238            // avoid clone the last one
239            let last_path_specs = &endpoint.path_specs[endpoint.path_specs.len() - 1];
240            let service_fn = match last_path_specs.method {
241                HttpMethod::Get => get_service,
242                HttpMethod::Post => post_service,
243                HttpMethod::Put => put_service,
244                HttpMethod::Delete => delete_service,
245            };
246            r = r.route(&last_path_specs.path.clone(), service_fn(endpoint));
247        }
248
249        if route.prefix == "/" {
250            // axum don't allow use nest() with root path
251            router = router.merge(r);
252        } else {
253            router = router.nest(&route.prefix, r);
254        }
255    }
256
257    let addr = SocketAddr::from(([0, 0, 0, 0], port));
258    let listener = TcpListener::bind(addr).await.unwrap();
259
260    match reload_rx {
261        Some(mut rx) => {
262            // Create a shutdown signal for reload case
263            let (close_tx, close_rx) = tokio::sync::oneshot::channel();
264
265            // Handle reload messages
266            let shutdown_task = tokio::spawn(async move {
267                if rx.recv().await.is_ok() {
268                    close_tx.send(()).unwrap();
269                }
270            });
271
272            // Run the server with graceful shutdown
273            axum::serve(listener, router)
274                .with_graceful_shutdown(async {
275                    let _ = close_rx.await;
276                })
277                .await
278                .unwrap();
279
280            shutdown_task.abort();
281        }
282        None => {
283            // Run without reload capability
284            axum::serve(listener, router).await.unwrap();
285        }
286    }
287}