1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/*
 * File: logger.rs
 * Project: link
 * Created Date: 10/05/2023
 * Author: Shun Suzuki
 * -----
 * Last Modified: 24/07/2023
 * Modified By: Shun Suzuki (suzuki@hapis.k.u-tokyo.ac.jp)
 * -----
 * Copyright (c) 2023 Shun Suzuki. All rights reserved.
 *
 */

use std::sync::{Arc, RwLock};

use spdlog::{
    formatter::{Formatter, FullFormatter},
    prelude::*,
    sink::Sink,
    ErrorHandler, StringBuf,
};

/// Optional error handler as stored by [`CustomSink`] via `Sink::set_error_handler`.
type SinkErrorHandler = Option<ErrorHandler>;

/// logger with custom output and flush callback
///
/// A [`Sink`] that hands each formatted record to the user-supplied `out`
/// callback and forwards flush requests to the user-supplied `flush` callback.
///
/// The `Fn` bounds on `O` and `F` live on the `impl` blocks that need them
/// rather than on the struct definition itself.
pub struct CustomSink<O, F> {
    /// Level filter reported through `Sink::level_filter`.
    level_filter: RwLock<LevelFilter>,
    /// Formatter that renders each record to text before `out` is called.
    formatter: RwLock<Box<dyn Formatter>>,
    /// Handler registered via `Sink::set_error_handler`. It is only stored;
    /// this sink's `log`/`flush` do not invoke it.
    error_handler: RwLock<SinkErrorHandler>,
    /// Output callback receiving each formatted record.
    out: O,
    /// Callback invoked by `Sink::flush`.
    flush: F,
}

impl<O, F> CustomSink<O, F>
where
    O: Fn(&str) -> spdlog::Result<()> + Send + Sync,
    F: Fn() -> spdlog::Result<()> + Send + Sync,
{
    pub fn new(out: O, flush: F) -> Self {
        Self {
            level_filter: RwLock::new(LevelFilter::All),
            formatter: RwLock::new(Box::new(FullFormatter::new())),
            error_handler: RwLock::new(None),
            out,
            flush,
        }
    }
}

impl<O, F> Sink for CustomSink<O, F>
where
    O: Fn(&str) -> spdlog::Result<()> + Send + Sync,
    F: Fn() -> spdlog::Result<()> + Send + Sync,
{
    /// Renders `record` with the current formatter and passes the text to the
    /// output callback.
    ///
    /// Records rejected by `should_log` are skipped and yield `Ok(())`. Errors
    /// from the formatter or from the output callback are returned to the caller.
    fn log(&self, record: &spdlog::Record) -> spdlog::Result<()> {
        if self.should_log(record.level()) {
            let mut rendered = StringBuf::new();
            // Keep the read guard scoped so it is released before the callback runs.
            {
                let active_formatter = self.formatter.read().unwrap();
                active_formatter.format(record, &mut rendered)?;
            }
            (self.out)(rendered.as_str())
        } else {
            Ok(())
        }
    }

    /// Delegates to the user-supplied flush callback.
    fn flush(&self) -> spdlog::Result<()> {
        let flush_callback = &self.flush;
        flush_callback()
    }

    /// Returns a copy of the current level filter.
    fn level_filter(&self) -> LevelFilter {
        let current = self.level_filter.read().unwrap();
        *current
    }

    /// Replaces the level filter.
    fn set_level_filter(&self, level_filter: LevelFilter) {
        let mut slot = self.level_filter.write().unwrap();
        *slot = level_filter;
    }

    /// Replaces the formatter used by subsequent `log` calls.
    fn set_formatter(&self, formatter: Box<dyn spdlog::formatter::Formatter>) {
        let mut slot = self.formatter.write().unwrap();
        *slot = formatter;
    }

    /// Stores (or clears, with `None`) the sink's error handler.
    fn set_error_handler(&self, handler: Option<spdlog::ErrorHandler>) {
        let mut slot = self.error_handler.write().unwrap();
        *slot = handler;
    }
}

/// Returns a new `Vec` holding clones of the `Arc` sink handles of spdlog's
/// global default logger, so the returned sinks are the same sink instances.
fn get_default_sink() -> Vec<Arc<dyn Sink>> {
    let global_logger = spdlog::default_logger();
    global_logger.sinks().to_vec()
}

/// Wraps the given callbacks in a [`CustomSink`] and returns it as the only
/// element of a sink list.
fn get_custom_sink<O, F>(out: O, flush: F) -> Vec<Arc<dyn Sink>>
where
    O: Fn(&str) -> spdlog::Result<()> + Send + Sync + 'static,
    F: Fn() -> spdlog::Result<()> + Send + Sync + 'static,
{
    let only_sink: Arc<dyn Sink> = Arc::new(CustomSink::new(out, flush));
    vec![only_sink]
}

/// Create default logger
///
/// Builds a logger named `AUTD3` that writes to the same sink instances as
/// spdlog's global default logger. The sink list is captured when this
/// function is called.
///
/// # Panics
///
/// Panics if spdlog rejects the builder configuration (`build` returned an
/// error). This is not expected with the fixed name and sinks used here.
pub fn get_logger() -> Logger {
    Logger::builder()
        .sinks(get_default_sink())
        .name("AUTD3")
        .build()
        .expect("failed to build the \"AUTD3\" logger")
}

/// Create logger with custom output and flush callback
///
/// Builds a logger named `AUTD3` whose only sink is a [`CustomSink`]. That
/// sink passes each formatted record to `out` and forwards flush requests
/// to `flush`.
///
/// # Panics
///
/// Panics if spdlog rejects the builder configuration (`build` returned an
/// error). This is not expected with the fixed name used here.
pub fn get_logger_with_custom_func<O, F>(out: O, flush: F) -> Logger
where
    O: Fn(&str) -> spdlog::Result<()> + Send + Sync + 'static,
    F: Fn() -> spdlog::Result<()> + Send + Sync + 'static,
{
    Logger::builder()
        .sinks(get_custom_sink(out, flush))
        .name("AUTD3")
        .build()
        .expect("failed to build the \"AUTD3\" logger")
}