1mod dataframe;
113mod schema;
114
115pub use polars::prelude as polars_prelude;
116
117use polars::prelude::*;
118use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
119use rusqlite::{Connection, OpenFlags, Row, ToSql};
120use std::collections::{HashMap, HashSet};
121use std::path::PathBuf;
122use std::time::Duration;
123use time::OffsetDateTime;
124
125use dataframe::query_to_dataframe;
126
/// Where the SQLite database that a [`BearDb`] reads from lives.
#[derive(Clone, Debug)]
pub enum DatabasePath {
    /// A database file on disk; connections to it are opened read-only.
    RealPath(PathBuf),
    /// A fresh in-memory database populated with the test fixture schema.
    /// Only compiled in test builds.
    #[cfg(test)]
    InMemory,
}
139
140impl DatabasePath {
141 fn open_connection(&self) -> Result<Connection, BearError> {
145 match self {
146 DatabasePath::RealPath(path) => {
147 let conn = Connection::open_with_flags(
151 path,
152 OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
153 )?;
154
155 conn.busy_timeout(Duration::from_millis(5000))?;
157
158 conn.pragma_update(None, "query_only", "ON")?;
160
161 Ok(conn)
162 }
163 #[cfg(test)]
164 DatabasePath::InMemory => {
165 let conn = Connection::open_in_memory()?;
166 schema::setup_test_schema(&conn)?;
167 Ok(conn)
168 }
169 }
170 }
171}
172
173#[derive(Debug, thiserror::Error)]
174pub enum BearError {
175 #[error("Unable to load users home directory")]
176 NoHomeDirectory,
177 #[error("SQL Error: {source}")]
178 SqlError {
179 #[from]
180 source: rusqlite::Error,
181 },
182 #[error("Polars Error: {source}")]
183 PolarsError {
184 #[from]
185 source: PolarsError,
186 },
187}
188
/// Options controlling which notes [`BearDb::notes`] returns.
///
/// `NotesQuery` is a consuming builder: every setter takes `self` by value and
/// returns the updated query, e.g. `NotesQuery::new().limit(5).include_trashed()`.
/// Calling a setter without using its result has no effect, which is why the
/// setters are `#[must_use]`.
///
/// Defaults: at most 10 notes, with trashed and archived notes excluded.
#[derive(Debug, Clone)]
pub struct NotesQuery {
    /// Maximum number of notes to return; `None` means no limit.
    limit: Option<u32>,
    /// Whether notes in the trash are returned.
    include_trashed: bool,
    /// Whether archived notes are returned.
    include_archived: bool,
}

impl Default for NotesQuery {
    /// A query limited to 10 notes that excludes trashed and archived notes.
    fn default() -> Self {
        Self {
            limit: Some(10),
            include_trashed: false,
            include_archived: false,
        }
    }
}

impl NotesQuery {
    /// Creates a query with the default settings (see [`NotesQuery::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns at most `limit` notes, replacing any previous limit.
    #[must_use = "builder methods return the updated query; discarding it loses the change"]
    pub fn limit(mut self, limit: u32) -> Self {
        self.limit = Some(limit);
        self
    }

    /// Removes the limit so every matching note is returned.
    #[must_use = "builder methods return the updated query; discarding it loses the change"]
    pub fn no_limit(mut self) -> Self {
        self.limit = None;
        self
    }

    /// Also returns notes that are in the trash.
    #[must_use = "builder methods return the updated query; discarding it loses the change"]
    pub fn include_trashed(mut self) -> Self {
        self.include_trashed = true;
        self
    }

    /// Also returns archived notes.
    #[must_use = "builder methods return the updated query; discarding it loses the change"]
    pub fn include_archived(mut self) -> Self {
        self.include_archived = true;
        self
    }

    /// Returns trashed and archived notes as well. This is equivalent to calling
    /// both [`include_trashed`](Self::include_trashed) and
    /// [`include_archived`](Self::include_archived).
    #[must_use = "builder methods return the updated query; discarding it loses the change"]
    pub fn include_all(mut self) -> Self {
        self.include_trashed = true;
        self.include_archived = true;
        self
    }
}
247
248pub struct BearDb {
250 db_path: DatabasePath,
251 _metadata: schema::BearDbMetadata,
252 normalizing_cte: String,
253}
254
255impl BearDb {
256 pub fn new() -> Result<Self, BearError> {
259 let home_dir = dirs::home_dir().ok_or(BearError::NoHomeDirectory)?;
260
261 let db_path = home_dir.join(
262 "Library/Group Containers/9K33E3U3T4.net.shinyfrog.bear/Application Data/database.sqlite",
263 );
264
265 Self::new_with_path(DatabasePath::RealPath(db_path))
266 }
267
268 pub(crate) fn new_with_path(db_path: DatabasePath) -> Result<Self, BearError> {
271 let connection = db_path.open_connection()?;
273
274 let metadata = schema::discover_metadata(&connection)?;
276
277 let normalizing_cte = schema::generate_normalizing_cte(&metadata);
279
280 drop(connection);
282
283 Ok(BearDb {
284 db_path,
285 _metadata: metadata,
286 normalizing_cte,
287 })
288 }
289
290 fn with_connection<F, R>(
293 &self,
294 f: F,
295 ) -> Result<R, BearError>
296 where
297 F: FnOnce(&Queryable) -> Result<R, BearError>,
298 {
299 let connection = self.db_path.open_connection()?;
301
302 let queryable = Queryable::new(&connection, &self.normalizing_cte);
304
305 f(&queryable)
308 }
309
310 pub fn tags(&self) -> Result<BearTags, BearError> {
312 self.with_connection(|queryable| {
313 let mut statement = queryable.prepare(
314 r"
315 SELECT
316 id,
317 name,
318 modified
319 FROM tags
320 ORDER BY name ASC",
321 )?;
322
323 let results: rusqlite::Result<Vec<BearTag>> = statement
324 .query_map([], |row| {
325 Ok(BearTag {
326 id: row.get("id")?,
327 name: row.get("name")?,
328 modified: row.get("modified")?,
329 })
330 })?
331 .collect();
332
333 let tags = results?.into_iter().map(|tag| (tag.id, tag)).collect();
334
335 Ok(BearTags { tags })
336 })
337 }
338
339 pub fn notes(
359 &self,
360 query: NotesQuery,
361 ) -> Result<Vec<BearNote>, BearError> {
362 self.with_connection(|queryable| {
363 let mut where_clauses = Vec::new();
365 if !query.include_trashed {
366 where_clauses.push("is_trashed <> 1");
367 }
368 if !query.include_archived {
369 where_clauses.push("is_archived <> 1");
370 }
371
372 let where_clause = if where_clauses.is_empty() {
373 String::new()
374 } else {
375 format!("WHERE {}", where_clauses.join(" AND "))
376 };
377
378 let limit_clause = query
379 .limit
380 .map(|l| format!("LIMIT {}", l))
381 .unwrap_or_default();
382
383 let query = format!(
384 r"
385 SELECT
386 id,
387 unique_id,
388 title,
389 content,
390 modified,
391 created,
392 is_pinned
393 FROM notes
394 {}
395 ORDER BY modified DESC
396 {}",
397 where_clause, limit_clause
398 );
399
400 let mut statement = queryable.prepare(&query)?;
401
402 let results: rusqlite::Result<Vec<BearNote>> =
403 statement.query_map([], note_from_row)?.collect();
404
405 Ok(results?)
406 })
407 }
408
409 pub fn note_links(
411 &self,
412 from: BearNoteId,
413 ) -> Result<Vec<BearNote>, BearError> {
414 self.with_connection(|queryable| {
415 let mut statement = queryable.prepare(
416 r"
417 SELECT
418 n.id,
419 n.unique_id,
420 n.title,
421 n.content,
422 n.modified,
423 n.created,
424 n.is_pinned
425 FROM notes as n
426 INNER JOIN note_links as nl ON nl.to_note_id = n.id
427 WHERE n.is_trashed <> 1 AND n.is_archived <> 1 AND nl.from_note_id = ?
428 ORDER BY n.modified DESC",
429 )?;
430
431 let results: rusqlite::Result<Vec<BearNote>> =
432 statement.query_map([from], note_from_row)?.collect();
433
434 Ok(results?)
435 })
436 }
437
438 pub fn note_tags(
440 &self,
441 from: BearNoteId,
442 ) -> Result<HashSet<BearTagId>, BearError> {
443 self.with_connection(|queryable| {
444 let mut statement = queryable.prepare(
445 r"
446 SELECT
447 tag_id
448 FROM note_tags
449 WHERE note_id = ?",
450 )?;
451
452 let results: rusqlite::Result<HashSet<BearTagId>> = statement
453 .query_map([from], |row| row.get("tag_id"))?
454 .collect();
455
456 Ok(results?)
457 })
458 }
459
460 pub fn query(
493 &self,
494 sql: &str,
495 ) -> Result<DataFrame, BearError> {
496 self.with_connection(|queryable| query_to_dataframe(queryable, sql))
497 }
498}
499
500pub struct Queryable<'a> {
503 conn: &'a Connection,
504 normalizing_cte: &'a str,
505}
506
507impl<'a> Queryable<'a> {
508 fn new(
510 conn: &'a Connection,
511 normalizing_cte: &'a str,
512 ) -> Self {
513 Self {
514 conn,
515 normalizing_cte,
516 }
517 }
518
519 #[cfg(test)]
523 pub(crate) fn new_for_test(
524 conn: &'a Connection,
525 normalizing_cte: &'a str,
526 ) -> Self {
527 Self::new(conn, normalizing_cte)
528 }
529
530 pub fn prepare(
533 &self,
534 user_sql: &str,
535 ) -> rusqlite::Result<rusqlite::Statement<'a>> {
536 let full_sql = format!("{}\n{}", self.normalizing_cte, user_sql);
537 self.conn.prepare(&full_sql)
538 }
539}
540
/// Raw integer primary key as read from SQLite.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct DbId(i64);

/// Strongly typed id of a row in the `notes` relation.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct BearNoteId(DbId);

/// Strongly typed id of a row in the `tags` relation.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct BearTagId(DbId);
549
550impl FromSql for DbId {
551 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
552 Ok(Self(value.as_i64()?))
553 }
554}
555
556impl FromSql for BearNoteId {
557 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
558 Ok(Self(FromSql::column_result(value)?))
559 }
560}
561
562impl FromSql for BearTagId {
563 fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
564 Ok(Self(FromSql::column_result(value)?))
565 }
566}
567
568impl ToSql for DbId {
569 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
570 self.0.to_sql()
571 }
572}
573
574impl ToSql for BearNoteId {
575 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
576 self.0.to_sql()
577 }
578}
579
580#[derive(Debug, Clone)]
581pub struct BearTag {
582 id: BearTagId,
583 name: String,
584 modified: Option<OffsetDateTime>,
585}
586
587impl BearTag {
588 pub fn id(&self) -> BearTagId {
589 self.id
590 }
591
592 pub fn name(&self) -> &str {
593 &self.name
594 }
595
596 pub fn modified(&self) -> Option<OffsetDateTime> {
597 self.modified
598 }
599}
600
601#[derive(Debug)]
602pub struct BearTags {
603 tags: HashMap<BearTagId, BearTag>,
604}
605
606impl BearTags {
607 pub fn get(
608 &self,
609 tag_id: &BearTagId,
610 ) -> Option<&BearTag> {
611 self.tags.get(tag_id)
612 }
613
614 pub fn count(&self) -> usize {
615 self.tags.len()
616 }
617
618 pub fn iter(&self) -> impl Iterator<Item = &BearTag> {
619 self.tags.values()
620 }
621
622 pub fn names(
623 &self,
624 tag_ids: &HashSet<BearTagId>,
625 ) -> HashSet<String> {
626 tag_ids
627 .iter()
628 .filter_map(|id| self.get(id).map(|t| t.name.clone()))
629 .collect()
630 }
631}
632
633#[derive(Debug)]
634pub struct BearNote {
635 id: BearNoteId,
636 unique_id: String,
637 title: String,
638 content: String,
639 modified: OffsetDateTime,
640 created: OffsetDateTime,
641 is_pinned: bool,
642}
643
644impl BearNote {
645 pub fn id(&self) -> BearNoteId {
646 self.id
647 }
648
649 pub fn unique_id(&self) -> &str {
650 &self.unique_id
651 }
652
653 pub fn title(&self) -> &str {
654 &self.title
655 }
656
657 pub fn content(&self) -> &str {
658 &self.content
659 }
660
661 pub fn modified(&self) -> OffsetDateTime {
662 self.modified
663 }
664
665 pub fn created(&self) -> OffsetDateTime {
666 self.created
667 }
668
669 pub fn is_pinned(&self) -> bool {
670 self.is_pinned
671 }
672}
673
674fn note_from_row(row: &Row) -> rusqlite::Result<BearNote> {
675 Ok(BearNote {
676 id: row.get("id")?,
677 unique_id: row.get("unique_id")?,
678 title: row.get("title")?,
679 content: row.get("content")?,
680 created: row.get("created")?,
681 modified: row.get("modified")?,
682 is_pinned: row.get("is_pinned")?,
683 })
684}
685
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test of the public API against the in-memory fixture database.
    ///
    /// The expected counts depend on the fixture data installed by
    /// `schema::setup_test_schema`.
    #[test]
    fn test_beardb_with_inmemory() {
        let db = BearDb::new_with_path(DatabasePath::InMemory).unwrap();

        // All fixture tags are loaded.
        assert_eq!(db.tags().unwrap().count(), 2);

        // The default query excludes trashed and archived notes.
        let default_notes = db.notes(NotesQuery::default()).unwrap();
        assert_eq!(default_notes.len(), 2);

        // Opting into everything, without a limit, returns every note.
        let every_note = db
            .notes(NotesQuery::new().include_all().no_limit())
            .unwrap();
        assert_eq!(every_note.len(), 3);

        // Ad-hoc SQL comes back as a DataFrame shaped like the projection.
        let untrashed = db
            .query("SELECT id, title FROM notes WHERE is_trashed = 0")
            .unwrap();
        assert_eq!(untrashed.height(), 2);
        assert_eq!(untrashed.width(), 2);

        // An aggregate yields a single Int64 cell.
        let counted = db.query("SELECT COUNT(*) as count FROM notes").unwrap();
        assert_eq!(counted.height(), 1);
        assert_eq!(counted.width(), 1);
        let cell = counted.column("count").unwrap().get(0).unwrap();
        match cell {
            AnyValue::Int64(n) => assert_eq!(n, 3),
            other => panic!("Expected Int64, got: {:?}", other),
        }

        // Joins across notes, note_tags and tags work.
        let tagged = db
            .query(
                r"
                SELECT n.title, t.name as tag_name
                FROM notes n
                JOIN note_tags nt ON n.id = nt.note_id
                JOIN tags t ON nt.tag_id = t.id
                ",
            )
            .unwrap();
        assert_eq!(tagged.height(), 2);
    }
}