git_auth/
db.rs

1use crate::{Login, Request, error::DatabaseError};
2use rusqlite::{Connection, params};
3use std::{env, fs};
4
5pub fn open() -> Result<Connection, DatabaseError> {
6    let path = env::home_dir()
7        .ok_or(DatabaseError::Path)?
8        .join(".local/share/git-auth");
9
10    if !path.exists() {
11        fs::create_dir_all(&path)?;
12    }
13
14    let conn = Connection::open(path.join("creds.db"))?;
15
16    conn.execute("PRAGMA foreign_keys = ON", ())?;
17    conn.execute(
18        "CREATE TABLE IF NOT EXISTS logins (
19            id INTEGER PRIMARY KEY AUTOINCREMENT,
20            username TEXT NOT NULL,
21            email TEXT,
22            host TEXT NOT NULL,
23            CONSTRAINT username_unique UNIQUE (host, username),
24            CONSTRAINT email_unique UNIQUE (host, email)
25        )
26        ",
27        (),
28    )?;
29
30    conn.execute(
31        "CREATE TABLE IF NOT EXISTS requests (
32            id INTEGER PRIMARY KEY AUTOINCREMENT,
33            protocol TEXT NOT NULL,
34            host TEXT NOT NULL,
35            owner TEXT,
36            valid BOOLEAN NOT NULL DEFAULT 0,
37            user_id INTEGER,
38            FOREIGN KEY (user_id) REFERENCES logins (id)
39        )
40        ",
41        (),
42    )?;
43
44    Ok(conn)
45}
46
47pub fn add_login(conn: &Connection, login: &Login) -> rusqlite::Result<i64> {
48    conn.query_row(
49        "
50        SELECT id FROM logins
51        WHERE username = ?1
52          AND email = ?2
53          AND host = ?3
54        ",
55        params![login.username, login.email, login.host],
56        |row| row.get("id"),
57    )
58    .or_else(|_| {
59        conn.execute(
60            "INSERT INTO logins (username, email, host) VALUES (?1, ?2, ?3)",
61            params![login.username, login.email, login.host],
62        )?;
63        Ok(conn.last_insert_rowid())
64    })
65}
66
67pub fn validate_request(conn: &Connection, request: &Request, valid: bool) -> rusqlite::Result<()> {
68    conn.execute(
69        "
70        UPDATE requests
71        SET valid = ?1
72        WHERE host = ?2
73            AND owner = ?3
74            AND protocol = ?4
75        ",
76        params![valid, request.host, request.owner, request.protocol],
77    )?;
78    Ok(())
79}
80
81pub fn add_request(conn: &Connection, request: &Request, user_id: &i64) -> rusqlite::Result<i64> {
82    conn.execute(
83        "INSERT INTO requests (protocol, owner, host, user_id) VALUES (?1, ?2, ?3, ?4)",
84        params![request.protocol, request.owner, request.host, user_id],
85    )?;
86    Ok(conn.last_insert_rowid())
87}
88
89pub fn fetch_login(conn: &Connection, request: &Request) -> rusqlite::Result<(Login, bool)> {
90    conn.query_row(
91        "
92        SELECT l.username, l.email, r.valid
93        FROM requests r
94        JOIN logins l ON r.user_id = l.id
95        WHERE r.host = ?1
96          AND r.owner = ?2
97          AND r.protocol = ?3
98        ",
99        params![request.host, request.owner, request.protocol],
100        |row| {
101            Ok((
102                Login::new(
103                    row.get("username")?,
104                    request.host.clone(),
105                    row.get("email")?,
106                ),
107                row.get("valid")?,
108            ))
109        },
110    )
111}
112
113pub fn fetch_available_logins(
114    conn: &Connection,
115    request: &Request,
116) -> rusqlite::Result<Vec<Login>> {
117    let mut stmt = conn.prepare("SELECT username, email FROM logins WHERE host = ?1")?;
118    stmt.query_map(params![request.host], |row| {
119        Ok(Login::new(
120            row.get("username")?,
121            request.host.clone(),
122            row.get("email")?,
123        ))
124    })?
125    .collect()
126}
127
128pub fn fetch_all_logins(conn: &Connection) -> rusqlite::Result<Vec<Login>> {
129    let mut stmt = conn.prepare("SELECT username, email, host FROM logins")?;
130    stmt.query_map((), |row| {
131        Ok(Login::new(
132            row.get("username")?,
133            row.get("host")?,
134            row.get("email")?,
135        ))
136    })?
137    .collect()
138}