pgml_dashboard/
lib.rs

1#[macro_use]
2extern crate rocket;
3
4use std::collections::HashMap;
5
6use rocket::form::Form;
7use rocket::response::Redirect;
8use rocket::route::Route;
9
10use sqlx::{postgres::PgPoolOptions, PgPool};
11
12use parking_lot::Mutex;
13use sailfish::TemplateOnce;
14use std::sync::Arc;
15
16mod errors;
17mod forms;
18pub mod guards;
19mod models;
20mod responses;
21mod templates;
22
23use guards::Cluster;
24use responses::{BadRequest, ResponseOk};
25
26#[derive(Debug)]
27pub struct Clusters {
28    pools: Arc<Mutex<HashMap<i64, PgPool>>>,
29}
30
31impl Clusters {
32    pub fn add(&self, cluster_id: i64, database_url: &str) -> anyhow::Result<PgPool> {
33        let mut pools = self.pools.lock();
34
35        let pool = PgPoolOptions::new()
36            .max_connections(5)
37            .idle_timeout(std::time::Duration::from_millis(15_000))
38            .min_connections(0)
39            .connect_lazy(database_url)?;
40
41        pools.insert(cluster_id, pool.clone());
42
43        Ok(pool)
44    }
45
46    pub fn get(&self, cluster_id: i64) -> Option<PgPool> {
47        match self.pools.lock().get(&cluster_id) {
48            Some(pool) => Some(pool.clone()),
49            None => None,
50        }
51    }
52
53    pub fn delete(&self, cluster_id: i64) {
54        let _ = self.pools.lock().remove(&cluster_id);
55    }
56
57    pub fn new() -> Clusters {
58        Clusters {
59            pools: Arc::new(Mutex::new(HashMap::new())),
60        }
61    }
62}
63
64#[get("/")]
65pub async fn index() -> Redirect {
66    Redirect::to("/dashboard/notebooks")
67}
68
69#[get("/projects")]
70pub async fn project_index(cluster: Cluster) -> Result<ResponseOk, errors::Error> {
71    Ok(ResponseOk(
72        templates::Projects {
73            topic: "projects".to_string(),
74            projects: models::Project::all(cluster.pool()).await?,
75        }
76        .render_once()
77        .unwrap(),
78    ))
79}
80
81#[get("/projects/<id>")]
82pub async fn project_get(cluster: Cluster, id: i64) -> Result<ResponseOk, errors::Error> {
83    let project = models::Project::get_by_id(cluster.pool(), id).await?;
84    let models = models::Model::get_by_project_id(cluster.pool(), id).await?;
85
86    Ok(ResponseOk(
87        templates::Project {
88            topic: "projects".to_string(),
89            project,
90            models,
91        }
92        .render_once()
93        .unwrap(),
94    ))
95}
96
97#[get("/notebooks")]
98pub async fn notebook_index(cluster: Cluster) -> Result<ResponseOk, errors::Error> {
99    Ok(ResponseOk(
100        templates::Notebooks {
101            topic: "notebooks".to_string(),
102            notebooks: models::Notebook::all(cluster.pool()).await?,
103        }
104        .render_once()
105        .unwrap(),
106    ))
107}
108
109#[post("/notebooks", data = "<data>")]
110pub async fn notebook_create(
111    cluster: Cluster,
112    data: Form<forms::Notebook<'_>>,
113) -> Result<Redirect, errors::Error> {
114    let notebook = crate::models::Notebook::create(cluster.pool(), data.name).await?;
115
116    Ok(Redirect::to(format!(
117        "/dashboard/notebooks/{}/",
118        notebook.id
119    )))
120}
121
122#[get("/notebooks/<id>")]
123pub async fn notebook_get(cluster: Cluster, id: i64) -> Result<ResponseOk, errors::Error> {
124    let notebook = models::Notebook::get_by_id(cluster.pool(), id).await?;
125
126    Ok(ResponseOk(
127        templates::Notebook {
128            topic: "notebooks".to_string(),
129            cells: notebook.cells(cluster.pool()).await?,
130            notebook: notebook,
131        }
132        .render_once()
133        .unwrap(),
134    ))
135}
136
137#[post("/notebooks/<id>/reset")]
138pub async fn notebook_reset(cluster: Cluster, id: i64) -> Result<Redirect, errors::Error> {
139    let notebook = models::Notebook::get_by_id(cluster.pool(), id).await?;
140    notebook.reset(cluster.pool()).await?;
141
142    Ok(Redirect::to(format!("/dashboard/notebooks/{}", id)))
143}
144
145#[post("/notebooks/<id>/cell", data = "<cell>")]
146pub async fn cell_create(
147    cluster: Cluster,
148    id: i64,
149    cell: Form<forms::Cell<'_>>,
150) -> Result<Redirect, errors::Error> {
151    let notebook = models::Notebook::get_by_id(cluster.pool(), id).await?;
152    let mut cell = models::Cell::create(
153        cluster.pool(),
154        &notebook,
155        cell.cell_type.parse::<i32>()?,
156        cell.contents,
157    )
158    .await?;
159    let _ = cell.render(cluster.pool()).await?;
160
161    Ok(Redirect::to(format!("/dashboard/notebooks/{}/", id)))
162}
163
164#[get("/notebooks/<notebook_id>/cell/<cell_id>")]
165pub async fn cell_get(
166    cluster: Cluster,
167    notebook_id: i64,
168    cell_id: i64,
169) -> Result<ResponseOk, errors::Error> {
170    let notebook = models::Notebook::get_by_id(cluster.pool(), notebook_id).await?;
171    let cell = models::Cell::get_by_id(cluster.pool(), cell_id).await?;
172
173    let bust_cache = std::time::SystemTime::now()
174        .duration_since(std::time::SystemTime::UNIX_EPOCH)?
175        .as_millis()
176        .to_string();
177
178    Ok(ResponseOk(
179        templates::Cell {
180            cell,
181            notebook,
182            selected: false,
183            edit: false,
184            bust_cache,
185        }
186        .render_once()
187        .unwrap(),
188    ))
189}
190
191#[post("/notebooks/<notebook_id>/cell/<cell_id>/edit", data = "<data>")]
192pub async fn cell_edit(
193    cluster: Cluster,
194    notebook_id: i64,
195    cell_id: i64,
196    data: Form<forms::Cell<'_>>,
197) -> Result<ResponseOk, errors::Error> {
198    let notebook = models::Notebook::get_by_id(cluster.pool(), notebook_id).await?;
199    let mut cell = models::Cell::get_by_id(cluster.pool(), cell_id).await?;
200
201    cell.update(
202        cluster.pool(),
203        data.cell_type.parse::<i32>()?,
204        &data.contents,
205    )
206    .await?;
207    cell.render(cluster.pool()).await?;
208
209    let bust_cache = std::time::SystemTime::now()
210        .duration_since(std::time::SystemTime::UNIX_EPOCH)?
211        .as_millis()
212        .to_string();
213
214    Ok(ResponseOk(
215        templates::Cell {
216            cell,
217            notebook,
218            selected: false,
219            edit: false,
220            bust_cache,
221        }
222        .render_once()
223        .unwrap(),
224    ))
225}
226
227#[get("/notebooks/<notebook_id>/cell/<cell_id>/edit")]
228pub async fn cell_trigger_edit(
229    cluster: Cluster,
230    notebook_id: i64,
231    cell_id: i64,
232) -> Result<ResponseOk, errors::Error> {
233    let notebook = models::Notebook::get_by_id(cluster.pool(), notebook_id).await?;
234    let cell = models::Cell::get_by_id(cluster.pool(), cell_id).await?;
235    let bust_cache = std::time::SystemTime::now()
236        .duration_since(std::time::SystemTime::UNIX_EPOCH)?
237        .as_millis()
238        .to_string();
239
240    Ok(ResponseOk(
241        templates::Cell {
242            cell,
243            notebook,
244            selected: false,
245            edit: true,
246            bust_cache,
247        }
248        .render_once()
249        .unwrap(),
250    ))
251}
252
253#[post("/notebooks/<notebook_id>/cell/<cell_id>/play")]
254pub async fn cell_play(
255    cluster: Cluster,
256    notebook_id: i64,
257    cell_id: i64,
258) -> Result<ResponseOk, errors::Error> {
259    let notebook = models::Notebook::get_by_id(cluster.pool(), notebook_id).await?;
260    let mut cell = models::Cell::get_by_id(cluster.pool(), cell_id).await?;
261    cell.render(cluster.pool()).await?;
262    let bust_cache = std::time::SystemTime::now()
263        .duration_since(std::time::SystemTime::UNIX_EPOCH)?
264        .as_millis()
265        .to_string();
266
267    Ok(ResponseOk(
268        templates::Cell {
269            cell,
270            notebook,
271            selected: true,
272            edit: false,
273            bust_cache,
274        }
275        .render_once()
276        .unwrap(),
277    ))
278}
279
280#[post("/notebooks/<notebook_id>/cell/<cell_id>/remove")]
281pub async fn cell_remove(
282    cluster: Cluster,
283    notebook_id: i64,
284    cell_id: i64,
285) -> Result<ResponseOk, errors::Error> {
286    let notebook = models::Notebook::get_by_id(cluster.pool(), notebook_id).await?;
287    let cell = models::Cell::get_by_id(cluster.pool(), cell_id).await?;
288    let bust_cache = std::time::SystemTime::now()
289        .duration_since(std::time::SystemTime::UNIX_EPOCH)?
290        .as_millis()
291        .to_string();
292
293    Ok(ResponseOk(
294        templates::Undo {
295            notebook,
296            cell,
297            bust_cache,
298        }
299        .render_once()?,
300    ))
301}
302
303#[post("/notebooks/<notebook_id>/cell/<cell_id>/delete")]
304pub async fn cell_delete(
305    cluster: Cluster,
306    notebook_id: i64,
307    cell_id: i64,
308) -> Result<Redirect, errors::Error> {
309    let _notebook = models::Notebook::get_by_id(cluster.pool(), notebook_id).await?;
310    let cell = models::Cell::get_by_id(cluster.pool(), cell_id).await?;
311
312    let _ = cell.delete(cluster.pool()).await?;
313
314    Ok(Redirect::to(format!(
315        "/dashboard/notebooks/{}",
316        notebook_id
317    )))
318}
319
320#[get("/models")]
321pub async fn models_index(cluster: Cluster) -> Result<ResponseOk, errors::Error> {
322    let projects = models::Project::all(cluster.pool()).await?;
323    let mut models = HashMap::new();
324    // let mut max_scores = HashMap::new();
325    // let mut min_scores = HashMap::new();
326
327    for project in &projects {
328        let project_models = models::Model::get_by_project_id(cluster.pool(), project.id).await?;
329        // let mut key_metrics = project_models
330        //     .iter()
331        //     .map(|m| m.key_metric(project).unwrap_or(0.))
332        //     .collect::<Vec<f64>>();
333        // key_metrics.sort_by(|a, b| a.partial_cmp(b).unwrap());
334
335        // max_scores.insert(project.id, key_metrics.iter().last().unwrap_or(&0.).clone());
336        // min_scores.insert(project.id, key_metrics.iter().next().unwrap_or(&0.).clone());
337
338        models.insert(project.id, project_models);
339    }
340
341    Ok(ResponseOk(
342        templates::Models {
343            topic: "models".to_string(),
344            projects,
345            models,
346            // min_scores,
347            // max_scores,
348        }
349        .render_once()
350        .unwrap(),
351    ))
352}
353
354#[get("/models/<id>")]
355pub async fn models_get(cluster: Cluster, id: i64) -> Result<ResponseOk, errors::Error> {
356    let model = models::Model::get_by_id(cluster.pool(), id).await?;
357    let snapshot = models::Snapshot::get_by_id(cluster.pool(), model.snapshot_id).await?;
358    let project = models::Project::get_by_id(cluster.pool(), model.project_id).await?;
359
360    Ok(ResponseOk(
361        templates::Model {
362            topic: "models".to_string(),
363            deployed: model.deployed(cluster.pool()).await?,
364            model,
365            snapshot,
366            project,
367        }
368        .render_once()
369        .unwrap(),
370    ))
371}
372
373#[get("/snapshots")]
374pub async fn snapshots_index(cluster: Cluster) -> Result<ResponseOk, errors::Error> {
375    let snapshots = models::Snapshot::all(cluster.pool()).await?;
376    let mut table_sizes = HashMap::new();
377
378    for snapshot in &snapshots {
379        let table_size = snapshot.table_size(cluster.pool()).await?;
380        table_sizes.insert(snapshot.id, table_size);
381    }
382
383    Ok(ResponseOk(
384        templates::Snapshots {
385            topic: "snapshots".to_string(),
386            snapshots,
387            table_sizes,
388        }
389        .render_once()
390        .unwrap(),
391    ))
392}
393
394#[get("/snapshots/<id>")]
395pub async fn snapshots_get(cluster: Cluster, id: i64) -> Result<ResponseOk, errors::Error> {
396    let snapshot = models::Snapshot::get_by_id(cluster.pool(), id).await?;
397    let samples = snapshot.samples(cluster.pool(), 500).await?;
398    let models = snapshot.models(cluster.pool()).await?;
399    let mut projects = HashMap::new();
400
401    for model in &models {
402        projects.insert(model.project_id, model.project(cluster.pool()).await?);
403    }
404
405    Ok(ResponseOk(
406        templates::Snapshot {
407            topic: "snapshots".to_string(),
408            table_size: snapshot.table_size(cluster.pool()).await?,
409            snapshot,
410            models,
411            projects,
412            samples,
413        }
414        .render_once()
415        .unwrap(),
416    ))
417}
418
419#[get("/deployments")]
420pub async fn deployments_index(cluster: Cluster) -> Result<ResponseOk, errors::Error> {
421    let projects = models::Project::all(cluster.pool()).await?;
422    let mut deployments = HashMap::new();
423
424    for project in projects.iter() {
425        deployments.insert(
426            project.id,
427            models::Deployment::get_by_project_id(cluster.pool(), project.id).await?,
428        );
429    }
430
431    Ok(ResponseOk(
432        templates::Deployments {
433            topic: "deployments".to_string(),
434            projects,
435            deployments,
436        }
437        .render_once()
438        .unwrap(),
439    ))
440}
441
442#[get("/deployments/<id>")]
443pub async fn deployments_get(cluster: Cluster, id: i64) -> Result<ResponseOk, errors::Error> {
444    let deployment = models::Deployment::get_by_id(cluster.pool(), id).await?;
445    let project = models::Project::get_by_id(cluster.pool(), deployment.project_id).await?;
446    let model = models::Model::get_by_id(cluster.pool(), deployment.model_id).await?;
447
448    Ok(ResponseOk(
449        templates::Deployment {
450            topic: "deployments".to_string(),
451            project,
452            deployment,
453            model,
454        }
455        .render_once()
456        .unwrap(),
457    ))
458}
459
460#[get("/uploader")]
461pub async fn uploader_index() -> ResponseOk {
462    ResponseOk(
463        templates::Uploader {
464            topic: "uploader".to_string(),
465            error: None,
466        }
467        .render_once()
468        .unwrap(),
469    )
470}
471
472#[post("/uploader", data = "<form>")]
473pub async fn uploader_upload(
474    cluster: Cluster,
475    form: Form<forms::Upload<'_>>,
476) -> Result<Redirect, BadRequest> {
477    let mut uploaded_file = models::UploadedFile::create(cluster.pool()).await.unwrap();
478
479    match uploaded_file
480        .upload(cluster.pool(), form.file.path().unwrap(), form.has_header)
481        .await
482    {
483        Ok(()) => Ok(Redirect::to(format!(
484            "/dashboard/uploader/done?table_name={}",
485            uploaded_file.table_name()
486        ))),
487        Err(err) => Err(BadRequest(
488            templates::Uploader {
489                topic: "uploader".to_string(),
490                error: Some(err.to_string()),
491            }
492            .render_once()
493            .unwrap(),
494        )),
495    }
496}
497
498#[get("/uploader/done?<table_name>")]
499pub async fn uploaded_index(cluster: Cluster, table_name: &str) -> ResponseOk {
500    let sql = templates::Sql::new(
501        cluster.pool(),
502        &format!("SELECT * FROM {} LIMIT 10", table_name),
503    )
504    .await
505    .unwrap();
506    ResponseOk(
507        templates::Uploaded {
508            topic: "uploader".to_string(),
509            table_name: table_name.to_string(),
510            columns: sql.columns.clone(),
511            sql,
512        }
513        .render_once()
514        .unwrap(),
515    )
516}
517
518pub fn paths() -> Vec<Route> {
519    routes![
520        index,
521        notebook_index,
522        project_index,
523        project_get,
524        notebook_create,
525        notebook_get,
526        notebook_reset,
527        cell_create,
528        cell_get,
529        cell_trigger_edit,
530        cell_edit,
531        cell_play,
532        cell_remove,
533        cell_delete,
534        models_index,
535        models_get,
536        snapshots_index,
537        snapshots_get,
538        deployments_index,
539        deployments_get,
540        uploader_index,
541        uploader_upload,
542        uploaded_index,
543    ]
544}
545
546pub async fn migrate(pool: &PgPool) -> anyhow::Result<()> {
547    Ok(sqlx::migrate!("./migrations").run(pool).await?)
548}