1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
use crate::{
    detections::Detections,
    error::Error,
    image::IntoCowImage,
    layers::{Layer, Layers},
    utils,
};
use darknet_sys as sys;

use std::{
    ffi::c_void,
    os::raw::c_int,
    path::Path,
    ptr::{self, NonNull},
    slice,
};

/// Owning wrapper around a darknet `network` created by [`Network::load`].
///
/// The wrapped network is released when this value is dropped.
#[derive(Debug)]
pub struct Network {
    /// Non-null pointer returned by darknet's `load_network`; released in `Drop`.
    net: NonNull<sys::network>,
}

impl Network {
    /// Build the network instance from a configuration file and an optional weights file.
    ///
    /// `clear` is forwarded as the `clear` flag of darknet's `load_network`.
    ///
    /// This will abort the program with an exit code of 1 if any of the following occur.
    /// - The config has no sections.
    /// - The first section of the config is not `[net]` or `[network]`.
    /// - `fopen` fails on `weights` (if provided).
    /// - The weights file is invalid.
    ///
    /// # Errors
    ///
    /// Returns an [`Err`] if `cfg` or `weights` (if provided) cannot be converted to a C
    /// string (e.g. it contains a null byte), or if darknet returns a null network pointer.
    pub fn load<C, W>(cfg: C, weights: Option<W>, clear: bool) -> Result<Network, Error>
    where
        C: AsRef<Path>,
        W: AsRef<Path>,
    {
        let weights_cstr = weights
            .map(|path| utils::path_to_cstring_or_error(path.as_ref()))
            .transpose()?;
        let cfg_cstr = utils::path_to_cstring_or_error(cfg.as_ref())?;

        // The C signature takes `char *`, hence the const-to-mut casts.
        let raw_weights = weights_cstr
            .as_ref()
            .map_or(ptr::null_mut(), |cstr| cstr.as_ptr() as *mut _);
        let raw_cfg = cfg_cstr.as_ptr() as *mut _;

        // SAFETY: both pointers borrow `CString`s owned by this function, which outlive the
        // FFI call (they are only dropped when the function returns).
        let ptr = unsafe { sys::load_network(raw_cfg, raw_weights, c_int::from(clear)) };

        let net = NonNull::new(ptr).ok_or_else(|| Error::InternalError {
            reason: "failed to load model".into(),
        })?;

        Ok(Self { net })
    }

    /// Get network input width.
    pub fn input_width(&self) -> usize {
        // SAFETY: `net` is non-null and valid for the lifetime of `self`.
        unsafe { self.net.as_ref().w as usize }
    }

    /// Get network input height.
    pub fn input_height(&self) -> usize {
        // SAFETY: `net` is non-null and valid for the lifetime of `self`.
        unsafe { self.net.as_ref().h as usize }
    }

    /// Get network input channels.
    pub fn input_channels(&self) -> usize {
        // SAFETY: `net` is non-null and valid for the lifetime of `self`.
        unsafe { self.net.as_ref().c as usize }
    }

    /// Get network input shape tuple (channels, height, width).
    pub fn input_shape(&self) -> (usize, usize, usize) {
        (
            self.input_channels(),
            self.input_height(),
            self.input_width(),
        )
    }

    /// Get the number of layers.
    pub fn num_layers(&self) -> usize {
        // SAFETY: `net` is non-null and valid for the lifetime of `self`.
        unsafe { self.net.as_ref().n as usize }
    }

    /// Get network layers.
    ///
    /// Yields an empty collection when the network reports no layers or its layer array is
    /// null.
    pub fn layers(&self) -> Layers {
        let count = self.num_layers();
        // SAFETY: `net` is non-null and valid for the lifetime of `self`.
        let raw = unsafe { self.net.as_ref().layers };

        // `slice::from_raw_parts` requires a non-null, aligned pointer even when the
        // length is zero, so substitute a dangling pointer for the empty case.
        let (base, len) = match NonNull::new(raw) {
            Some(base) if count > 0 => (base, count),
            _ => (NonNull::dangling(), 0),
        };

        // SAFETY: either `base` points at darknet's array of `count` layers, which lives as
        // long as the network, or `len` is 0 and `base` is a valid dangling pointer.
        let layers = unsafe { slice::from_raw_parts(base.as_ptr(), len) };
        Layers { layers }
    }

    /// Get layer by index.
    ///
    /// Returns `None` if `index` is out of range (or the layer array is null).
    pub fn get_layer(&self, index: usize) -> Option<Layer> {
        if index >= self.num_layers() {
            return None;
        }

        // SAFETY: `index < n`, so the offset stays inside darknet's layer array.
        let layer = unsafe { self.net.as_ref().layers.add(index).as_ref() }?;
        Some(Layer { layer })
    }

    /// Run inference on an image.
    ///
    /// - `thresh` and `hier_thres` are forwarded to darknet's `get_network_boxes`.
    /// - `nms` is the threshold passed to the NMS routine; `0.0` skips NMS entirely.
    /// - `use_letter_box` selects `network_predict_image_letterbox` over
    ///   `network_predict_image`, and is also forwarded to `get_network_boxes`.
    ///
    /// # Panics
    ///
    /// Panics if the network has no layers, if darknet returns a null detection array, or if
    /// the network's layer array is null.
    pub fn predict<'a, M>(
        &mut self,
        image: M,
        thresh: f32,
        hier_thres: f32,
        nms: f32,
        use_letter_box: bool,
    ) -> Detections
    where
        M: IntoCowImage<'a>,
    {
        // Checked up front: `n - 1` on an empty network would otherwise wrap to
        // `usize::MAX` in release builds and produce an out-of-bounds pointer.
        let output_index = self
            .num_layers()
            .checked_sub(1)
            .expect("cannot run prediction on a network with no layers");

        let cow = image.into_cow_image();

        // SAFETY: `net` is non-null and valid; `cow.image` stays alive for this whole block.
        unsafe {
            // run prediction
            if use_letter_box {
                sys::network_predict_image_letterbox(self.net.as_ptr(), cow.image);
            } else {
                sys::network_predict_image(self.net.as_ptr(), cow.image);
            }

            let mut nboxes: c_int = 0;
            let dets_ptr = sys::get_network_boxes(
                self.net.as_ptr(),
                cow.width() as c_int,
                cow.height() as c_int,
                thresh,
                hier_thres,
                ptr::null_mut(),
                1,
                &mut nboxes,
                c_int::from(use_letter_box),
            );
            let dets = NonNull::new(dets_ptr).expect("darknet returned a null detection array");

            // Borrow the output layer only after darknet has finished mutating the network,
            // so no Rust reference into the C-owned layer array is held across those calls.
            let output_layer = self
                .net
                .as_ref()
                .layers
                .add(output_index)
                .as_ref()
                .expect("network layer array is null");

            // NMS sort
            if nms != 0.0 {
                if output_layer.nms_kind == sys::NMS_KIND_DEFAULT_NMS {
                    sys::do_nms_sort(dets.as_ptr(), nboxes, output_layer.classes, nms);
                } else {
                    sys::diounms_sort(
                        dets.as_ptr(),
                        nboxes,
                        output_layer.classes,
                        nms,
                        output_layer.nms_kind,
                        output_layer.beta_nms,
                    );
                }
            }

            Detections {
                detections: dets,
                n_detections: nboxes as usize,
            }
        }
    }
}

impl Drop for Network {
    /// Tears down the darknet network: first the resources owned by the `network` struct,
    /// then the heap block holding the struct itself.
    fn drop(&mut self) {
        let raw = self.net.as_ptr();
        // SAFETY: `raw` is the non-null pointer obtained from `load_network` and is released
        // exactly once, here. `free_network` receives a bitwise copy of the struct and frees
        // what it owns; the struct's own storage was calloc'd by darknet, so it must be handed
        // back with `libc::free`.
        unsafe {
            sys::free_network(raw.read());
            libc::free(raw.cast::<c_void>());
        }
    }
}

// SAFETY: `Network` is the sole owner of its darknet `network` allocation. It is created in
// `Network::load`, released in `Drop`, and the raw pointer is never handed out by this type,
// so moving a `Network` to another thread moves that ownership with it. No `Sync` impl is
// provided in this file, so the pointer cannot be used from two threads at once via `&Network`.
// NOTE(review): assumes darknet keeps no thread-affine (e.g. thread-local) state tied to a
// network — confirm against the darknet version in use.
unsafe impl Send for Network {}