1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
use lunatic_sqlite_api::guest_api::sqlite_guest_bindings as bindings;
use lunatic_sqlite_api::wire_format::{BindKey, BindList, BindPair, SqliteRow};

use super::client::SqliteClient;
use super::error::{SqliteCode, SqliteError, SqliteErrorExt};
use super::value::Value;
use crate::host::call_host_alloc;

/// Trait for querying data and executing queries.
///
/// Offers three entry points: running a query and collecting its rows,
/// preparing a [`Statement`] so values can be bound before running it, and
/// executing a query whose results are not needed.
pub trait Query {
    /// Executes a query with no bindings.
    ///
    /// Returns every result row, each row being a `Vec` of column [`Value`]s.
    /// The signature has no error channel; NOTE(review): the `SqliteClient`
    /// implementation in this file panics on unexpected step codes instead.
    fn query(&self, query: &str) -> Vec<Vec<Value>>;
    /// Prepares a query with bindings.
    ///
    /// The returned [`Statement`] starts without any bound values; attach them
    /// with [`Statement::bind`] or [`Statement::bind_named`] and then run it
    /// with [`Statement::execute`] or [`Statement::execute_iter`].
    fn prepare_query(&self, query: &str) -> Statement;
    /// Executes a query, ignoring any results.
    ///
    /// # Errors
    ///
    /// Returns a [`SqliteError`] when the query could not be executed.
    fn execute(&self, query: &str) -> Result<(), SqliteError>;
}

impl Query for SqliteClient {
    /// Prepares `query` without any bindings and collects every row it yields.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Statement::execute`].
    fn query(&self, query: &str) -> Vec<Vec<Value>> {
        let statement = self.prepare_query(query);
        statement.execute()
    }

    /// Asks the host to prepare `query` for this client and wraps the returned
    /// id in a [`Statement`] whose binding list is still empty.
    fn prepare_query(&self, query: &str) -> Statement {
        let client_id = self.id();
        let sql_ptr = query.as_ptr();
        let sql_len = query.len() as u32;
        // SAFETY: the pointer/length pair is taken from `query`, which stays
        // borrowed for the whole duration of the host call.
        let id = unsafe { bindings::query_prepare(client_id, sql_ptr, sql_len) };

        Statement {
            id,
            bindings: BindList(vec![]),
        }
    }

    /// Hands `query` to the host's `execute` call and converts the returned
    /// code into a `Result` through `into_sqlite_error`.
    fn execute(&self, query: &str) -> Result<(), SqliteError> {
        let client_id = self.id();
        let sql_ptr = query.as_ptr();
        let sql_len = query.len() as u32;
        // SAFETY: the pointer/length pair is taken from `query`, which stays
        // borrowed for the whole duration of the host call.
        let code = unsafe { bindings::execute(client_id, sql_ptr, sql_len) };

        code.into_sqlite_error()
    }
}

/// Prepared SQL statement.
///
/// Obtained from [`Query::prepare_query`]. Values are attached with
/// [`Statement::bind`] or [`Statement::bind_named`], and the statement is run
/// with [`Statement::execute`] or [`Statement::execute_iter`]. Dropping the
/// statement finalizes it on the host.
pub struct Statement {
    // Statement id handed back by the host's `query_prepare` call.
    id: u64,
    // Bindings collected so far; sent to the host in `execute_iter`.
    bindings: BindList,
}

impl Statement {
    /// Adds a positional binding and returns the statement for chaining.
    ///
    /// The index used is one more than that of the most recently added numeric
    /// binding, or `1` if there is none yet. Named bindings are skipped when
    /// looking for the previous index.
    pub fn bind(mut self, value: impl Into<Value>) -> Self {
        // Walk the bindings newest-first and stop at the first numeric key.
        let mut next_idx = 1;
        for binding in self.bindings.iter().rev() {
            if let BindPair(BindKey::Numeric(idx), _) = binding {
                next_idx = idx + 1;
                break;
            }
        }

        let key = BindKey::Numeric(next_idx);
        let value: Value = value.into();
        self.bindings.0.push(BindPair(key, value.into()));
        self
    }

    /// Adds a binding keyed by `name` and returns the statement for chaining.
    pub fn bind_named(mut self, name: impl Into<String>, value: impl Into<Value>) -> Self {
        let key = BindKey::String(name.into());
        let value: Value = value.into();
        self.bindings.0.push(BindPair(key, value.into()));
        self
    }

    /// Runs the statement and gathers all of its rows into a `Vec`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Statement::execute_iter`] and the
    /// returned iterator's `next`.
    pub fn execute(self) -> Vec<Vec<Value>> {
        let rows = self.execute_iter();
        rows.collect()
    }

    /// Sends the collected bindings to the host and returns a row iterator.
    ///
    /// The bindings are serialized with `bincode` and passed to the host right
    /// away, but no row is stepped until the returned iterator is advanced.
    ///
    /// # Panics
    ///
    /// Panics if the bindings cannot be serialized.
    pub fn execute_iter(self) -> QueryIter {
        let payload = bincode::serialize(&self.bindings).unwrap();
        let payload_ptr = payload.as_ptr() as u32;
        let payload_len = payload.len() as u32;
        // SAFETY: `payload` is alive until after the host call returns, so the
        // pointer/length pair refers to valid memory for the whole call.
        unsafe { bindings::bind_value(self.id, payload_ptr, payload_len) };

        QueryIter { statement: self }
    }
}

impl Drop for Statement {
    /// Finalizes the statement on the host once the guest-side handle goes
    /// away. Whatever the host call returns is discarded.
    fn drop(&mut self) {
        let statement_id = self.id;
        // SAFETY: NOTE(review) relies on `statement_id` being the id this
        // statement was created with; it is only ever set in `prepare_query`.
        unsafe { bindings::sqlite3_finalize(statement_id) };
    }
}

/// Iterator for iterating query result rows.
///
/// Returned by [`Statement::execute_iter`]. Each call to `next` steps the
/// underlying statement once and yields the row as a `Vec<Value>`.
pub struct QueryIter {
    // The iterator owns the statement, so the statement's `Drop` impl runs
    // when the iterator itself is dropped.
    statement: Statement,
}

impl Iterator for QueryIter {
    type Item = Vec<Value>;

    /// Steps the statement once and, if a row is available, reads it from the
    /// host and converts each column into a [`Value`].
    ///
    /// Returns `None` when the step reports `SqliteCode::Done`.
    ///
    /// # Panics
    ///
    /// Panics if the step returns any code other than `Done` or `Row`, or if
    /// reading the row through `call_host_alloc` fails.
    fn next(&mut self) -> Option<Self::Item> {
        let statement_id = self.statement.id;
        // SAFETY: NOTE(review) relies on `statement_id` still referring to a
        // live host statement; it is finalized only when `Statement` drops.
        let step_code = unsafe { bindings::sqlite3_step(statement_id) };

        match SqliteCode::from_code(step_code) {
            Some(SqliteCode::Row) => {}
            Some(SqliteCode::Done) => return None,
            Some(code) => panic!("unexpected code {code:?} from lunatic::sqlite::sqlite3_step. Expected SQLITE_DONE or SQLITE_ROW"),
            None => panic!("unexpected code from lunatic::sqlite::sqlite3_step. Expected SQLITE_DONE or SQLITE_ROW"),
        }

        // A row is ready: fetch it from the host and decode it.
        let row = call_host_alloc::<SqliteRow>(|len_ptr| unsafe {
            bindings::read_row(statement_id, len_ptr)
        })
        .unwrap();

        let values: Vec<Value> = row.0.into_iter().map(|value| value.into()).collect();
        Some(values)
    }
}