/*!
Bindings to the DeepSpeech library
*/

extern crate libc;
extern crate deepspeech_sys;

use std::ffi::CStr;
use std::path::Path;
use std::ops::Drop;
use std::ptr;
use std::mem::forget;
use libc::free;
use deepspeech_sys as ds;

/// Handle to a loaded DeepSpeech model.
///
/// Created by `Model::load_from_files`; the wrapped native state is handed
/// to `DS_DestroyModel` when the value is dropped.
pub struct Model {
	// Raw handle filled in by `DS_CreateModel` and passed to the native
	// calls made by `Model`'s methods.
	model :* mut ds::ModelState,
}

/// Converts a path into a NUL-terminated byte buffer that can be passed
/// to the C API as a string.
///
/// The bytes of the path are copied verbatim and a single `0` byte is
/// appended. An interior NUL byte in the path is not rejected; C code
/// reading the buffer as a string will stop at the first one.
///
/// # Panics
///
/// Panics if the path is not valid UTF-8.
fn path_to_buf(p :&Path) -> Vec<u8> {
	let s = p.to_str().expect("path must be valid UTF-8");
	// Reserve room for the terminating NUL as well, so that pushing it
	// does not force a reallocation of the freshly created buffer.
	let mut v = Vec::with_capacity(s.len() + 1);
	v.extend_from_slice(s.as_bytes());
	v.push(0);
	v
}

impl Model {
	/// Load a DeepSpeech model from the specified model and alphabet file paths
	///
	/// `n_cep`, `n_context` and `beam_width` are forwarded unchanged to
	/// `DS_CreateModel`.
	///
	/// # Errors
	///
	/// Returns `Err(())` if `DS_CreateModel` reports a non-zero status.
	///
	/// # Panics
	///
	/// Panics if one of the paths is not valid UTF-8.
	pub fn load_from_files(model_path :&Path, n_cep :u16, n_context :u16,
			alphabet_path :&Path, beam_width :u16) -> Result<Self, ()> {
		let mp = path_to_buf(model_path);
		let ap = path_to_buf(alphabet_path);
		let mut model = ptr::null_mut();
		// SAFETY: both path buffers are NUL-terminated and stay alive for
		// the duration of the call; `model` is a valid out-pointer.
		let ret = unsafe {
			ds::DS_CreateModel(
				mp.as_ptr() as _,
				n_cep as _,
				n_context as _,
				ap.as_ptr() as _,
				beam_width as _,
				&mut model)
		};
		if ret != 0 {
			return Err(());
		}
		Ok(Model {
			model,
		})
	}

	/// Load a KenLM language model from a file and enable decoding using beam scoring
	///
	/// # Panics
	///
	/// Panics if one of the paths is not valid UTF-8.
	pub fn enable_decoder_with_lm(&mut self, alphabet_path :&Path,
			language_model_path :&Path, trie_path :&Path,
			weight :f32,
			valid_word_count_weight :f32) {
		let ap = path_to_buf(alphabet_path);
		let lp = path_to_buf(language_model_path);
		let tp = path_to_buf(trie_path);
		// NOTE(review): whatever the native call returns is discarded here,
		// so a failure to load the language model is not reported to the
		// caller — confirm whether the binding returns a status code.
		//
		// SAFETY: the three path buffers are NUL-terminated and outlive
		// the call; `self.model` is the handle stored at construction.
		unsafe {
			ds::DS_EnableDecoderWithLM(
				self.model,
				ap.as_ptr() as _,
				lp.as_ptr() as _,
				tp.as_ptr() as _,
				weight,
				valid_word_count_weight);
		}
	}

	/// Perform speech-to-text using the model
	///
	/// The input buffer must consist of mono 16-bit samples.
	/// The sample rate is not freely chooseable but a property
	/// of the model files.
	///
	/// # Errors
	///
	/// Returns an error if the text produced by the library is not
	/// valid UTF-8.
	///
	/// # Panics
	///
	/// Panics if the native call returns a null pointer, which the
	/// DeepSpeech C API uses to signal an error. Reading such a pointer
	/// as a string would be undefined behaviour.
	pub fn speech_to_text(&mut self, buffer :&[i16],
			sample_rate :u32) -> Result<String, std::string::FromUtf8Error> {
		// SAFETY: `self.model` is the handle stored at construction and
		// the sample pointer is passed together with the slice length.
		// The returned string is checked for null, copied, and only then
		// released with `free`; it is not touched afterwards.
		let r = unsafe {
			let text = ds::DS_SpeechToText(
				self.model,
				buffer.as_ptr(),
				buffer.len() as _,
				sample_rate as _);
			assert!(!text.is_null(), "DS_SpeechToText returned a null pointer");
			let v = CStr::from_ptr(text).to_bytes().to_vec();
			free(text as _);
			v
		};
		String::from_utf8(r)
	}

	/// Set up a state for streaming inference
	///
	/// `pre_alloc_frames` and `sample_rate` are forwarded unchanged to
	/// `DS_SetupStream`.
	///
	/// # Errors
	///
	/// Returns `Err(())` if `DS_SetupStream` reports a non-zero status.
	pub fn setup_stream(&mut self, pre_alloc_frames :u32, sample_rate :u32) -> Result<Stream, ()> {
		let mut ptr = ptr::null_mut();
		// SAFETY: `self.model` is the handle stored at construction and
		// `ptr` is a valid out-pointer for the new stream handle.
		let ret = unsafe {
			ds::DS_SetupStream(
				self.model,
				pre_alloc_frames as _,
				sample_rate as _,
				&mut ptr,
			)
		};
		if ret != 0 {
			return Err(());
		}
		Ok(Stream {
			stream : ptr
		})
	}
}

impl Drop for Model {
	/// Hands the wrapped model handle to `DS_DestroyModel`.
	fn drop(&mut self) {
		let handle = self.model;
		// SAFETY: `handle` is the pointer stored in this `Model`; the
		// value is going away, so the handle is not used again.
		unsafe { ds::DS_DestroyModel(handle); }
	}
}

/// State of a streaming inference session, as returned by
/// `Model::setup_stream`.
///
/// Dropping the value passes the handle to `DS_DiscardStream`; calling
/// `finish` consumes it and returns the decoded text instead.
///
/// NOTE(review): nothing in the type ties a `Stream` to the `Model` it was
/// created from — confirm whether the native stream may outlive its model.
pub struct Stream {
	// Raw handle filled in by `DS_SetupStream`.
	stream :* mut ds::StreamingState,
}

impl Stream {
	/// Copies a NUL-terminated string returned by the native library into
	/// an owned byte vector (without the terminator) and releases the
	/// native allocation with `free`.
	///
	/// # Safety
	///
	/// `ptr` must either be null or point to a NUL-terminated string that
	/// may be released with `free` and is not used by anyone afterwards.
	///
	/// # Panics
	///
	/// Panics if `ptr` is null, as reading it would be undefined behaviour.
	unsafe fn take_c_string(ptr :*const std::os::raw::c_char) -> Vec<u8> {
		assert!(!ptr.is_null(), "DeepSpeech returned a null string");
		let bytes = CStr::from_ptr(ptr).to_bytes().to_vec();
		free(ptr as _);
		bytes
	}

	/// Feed audio samples to the stream
	///
	/// The input buffer must consist of mono 16-bit samples.
	pub fn feed_audio(&mut self, buffer :&[i16]) {
		// SAFETY: `self.stream` is the handle stored in this `Stream` and
		// the sample pointer is passed together with the slice length.
		unsafe {
			ds::DS_FeedAudioContent(self.stream, buffer.as_ptr(), buffer.len() as _);
		}
	}

	/// Decodes the intermediate state of what has been spoken up until now
	///
	/// Note that as of DeepSpeech version 0.2.0,
	/// this function is non-trivial as the decoder can't do streaming yet.
	///
	/// # Errors
	///
	/// Returns an error if the text produced by the library is not
	/// valid UTF-8.
	///
	/// # Panics
	///
	/// Panics if the native call returns a null pointer.
	pub fn intermediate_decode(&mut self) -> Result<String, std::string::FromUtf8Error> {
		// SAFETY: `self.stream` is the handle stored in this `Stream`; the
		// returned string is consumed exactly once by the helper.
		let bytes = unsafe {
			Self::take_c_string(ds::DS_IntermediateDecode(self.stream) as _)
		};
		String::from_utf8(bytes)
	}

	/// Deallocates the stream and returns the decoded text
	///
	/// # Errors
	///
	/// Returns an error if the text produced by the library is not
	/// valid UTF-8.
	///
	/// # Panics
	///
	/// Panics if the native call returns a null pointer.
	pub fn finish(self) -> Result<String, std::string::FromUtf8Error> {
		let stream = self.stream;
		// Don't run the destructor for self,
		// as DS_FinishStream already does it for us.
		// This happens before the native call so that no later exit path
		// (such as the null check) can discard the stream a second time.
		forget(self);
		// SAFETY: `stream` is the handle that was stored in `self`; it is
		// passed to the library exactly once and never used again.
		let bytes = unsafe {
			Self::take_c_string(ds::DS_FinishStream(stream) as _)
		};
		String::from_utf8(bytes)
	}
}

impl Drop for Stream {
	/// Hands the wrapped stream handle to `DS_DiscardStream`.
	///
	/// `finish` bypasses this destructor, so it only runs for streams
	/// that were dropped without being finished.
	fn drop(&mut self) {
		let handle = self.stream;
		// SAFETY: `handle` is the pointer stored in this `Stream`; the
		// value is going away, so the handle is not used again.
		unsafe { ds::DS_DiscardStream(handle); }
	}
}