1use crate::model::events::Events;
2use crate::model::http::UpdatePlayer;
3use crate::model::player::ConnectionInfo;
4use crate::prelude::PlayerContext;
5
6use futures::future::BoxFuture;
7use parking_lot::RwLock;
8use pyo3::prelude::*;
9use pyo3::types::PyList;
10
11#[pymodule]
12pub fn client(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
13 m.add_class::<crate::client::LavalinkClient>()?;
14
15 Ok(())
16}
17
18fn raw_event(
19 _: crate::client::LavalinkClient,
20 session_id: String,
21 event: &serde_json::Value,
22) -> BoxFuture<()> {
23 Box::pin(async move {
24 debug!("{:?} -> {:?}", session_id, event);
25 })
26}
27
28#[pymethods]
29impl crate::client::LavalinkClient {
30 #[pyo3(name = "new")]
31 #[pyo3(signature = (events, nodes, strategy, user_data=None))]
32 #[staticmethod]
33 fn new_py<'a>(
34 py: Python<'a>,
35 events: PyObject,
36 nodes: Vec<crate::node::NodeBuilder>,
37 strategy: super::model::client::NodeDistributionStrategyPy,
38 user_data: Option<PyObject>,
39 ) -> PyResult<Bound<'a, PyAny>> {
40 let current_loop = pyo3_async_runtimes::get_running_loop(py)?;
41 let loop_ref = PyObject::from(current_loop);
42
43 let event_handler = crate::python::event::EventHandler {
44 inner: events,
45 current_loop: loop_ref,
46 };
47
48 let events = Events {
49 raw: Some(raw_event),
50 event_handler: Some(event_handler),
51 ..Default::default()
52 };
53
54 pyo3_async_runtimes::tokio::future_into_py_with_locals(
55 py,
56 pyo3_async_runtimes::tokio::get_current_locals(py)?,
57 async move {
58 if let Some(data) = user_data {
59 Ok(crate::client::LavalinkClient::new_with_data(
60 events,
61 nodes,
62 strategy.inner,
63 std::sync::Arc::new(RwLock::new(data)),
64 )
65 .await)
66 } else {
67 Ok(crate::client::LavalinkClient::new_with_data(
68 events,
69 nodes,
70 strategy.inner,
71 std::sync::Arc::new(RwLock::new(Python::with_gil(|py| py.None()))),
72 )
73 .await)
74 }
75 },
76 )
77 }
78
79 #[pyo3(name = "create_player_context")]
80 #[pyo3(signature = (guild_id, endpoint, token, session_id, user_data=None))]
81 fn create_player_context_py<'a>(
82 &self,
83 py: Python<'a>,
84 guild_id: super::model::PyGuildId,
85 endpoint: String,
86 token: String,
87 session_id: String,
88 user_data: Option<PyObject>,
89 ) -> PyResult<Bound<'a, PyAny>> {
90 let client = self.clone();
91
92 pyo3_async_runtimes::tokio::future_into_py_with_locals(
93 py,
94 pyo3_async_runtimes::tokio::get_current_locals(py)?,
95 async move {
96 if let Some(data) = user_data {
97 Ok(client
98 .create_player_context_with_data(
99 guild_id,
100 ConnectionInfo {
101 endpoint,
102 token,
103 session_id,
104 },
105 std::sync::Arc::new(RwLock::new(data)),
106 )
107 .await?)
108 } else {
109 Ok(client
110 .create_player_context_with_data(
111 guild_id,
112 ConnectionInfo {
113 endpoint,
114 token,
115 session_id,
116 },
117 std::sync::Arc::new(RwLock::new(Python::with_gil(|py| py.None()))),
118 )
119 .await?)
120 }
121 },
122 )
123 }
124
125 #[pyo3(name = "create_player")]
126 fn create_player_py<'a>(
127 &self,
128 py: Python<'a>,
129 guild_id: super::model::PyGuildId,
130 endpoint: String,
131 token: String,
132 session_id: String,
133 ) -> PyResult<Bound<'a, PyAny>> {
134 let client = self.clone();
135
136 pyo3_async_runtimes::tokio::future_into_py(py, async move {
137 let player = client
138 .create_player(
139 guild_id,
140 ConnectionInfo {
141 endpoint,
142 token,
143 session_id,
144 },
145 )
146 .await?;
147
148 Ok(Python::with_gil(|_py| player))
149 })
150 }
151
152 #[pyo3(name = "get_player_context")]
153 fn get_player_context_py<'a>(
154 &self,
155 guild_id: super::model::PyGuildId,
156 ) -> PyResult<Option<PlayerContext>> {
157 let player = self.get_player_context(guild_id);
158
159 Ok(player)
160 }
161
162 #[pyo3(name = "get_node_by_index")]
163 fn get_node_by_index_py(&self, idx: usize) -> Option<super::node::Node> {
164 self.get_node_by_index(idx)
165 .map(|x| super::node::Node { inner: x })
166 }
167
168 #[pyo3(name = "get_node_for_guild")]
169 pub fn get_node_for_guild_py<'a>(
170 &self,
171 py: Python<'a>,
172 guild_id: super::model::PyGuildId,
173 ) -> PyResult<Bound<'a, PyAny>> {
174 let client = self.clone();
175
176 pyo3_async_runtimes::tokio::future_into_py(py, async move {
177 let res = client.get_node_for_guild(guild_id).await;
178
179 Ok(Python::with_gil(|_py| super::node::Node { inner: res }))
180 })
181 }
182
183 #[pyo3(name = "load_tracks")]
184 fn load_tracks_py<'a>(
185 &self,
186 py: Python<'a>,
187 guild_id: super::model::PyGuildId,
188 identifier: String,
189 ) -> PyResult<Bound<'a, PyAny>> {
190 let client = self.clone();
191
192 pyo3_async_runtimes::tokio::future_into_py(py, async move {
193 let tracks = client.load_tracks(guild_id, &identifier).await?;
194
195 use crate::model::track::TrackLoadData::*;
196
197 Python::with_gil(|py| {
198 let track_data: Option<PyObject> = match tracks.data {
199 Some(Track(x)) => Some(x.into_pyobject(py).unwrap().into_any()),
200 Some(Playlist(x)) => Some(x.into_pyobject(py).unwrap().into_any()),
201 Some(Search(x)) => {
202 let l = PyList::empty(py);
203 for i in x {
204 l.append(i.into_pyobject(py).unwrap())?;
205 }
206
207 Some(l.into_pyobject(py).unwrap().into_any())
208 }
209 Some(Error(x)) => Some(x.into_pyobject(py).unwrap().into_any()),
210 None => None,
211 }
212 .map(|x| x.into());
213
214 Ok(super::model::track::Track {
215 load_type: tracks.load_type,
216 data: track_data,
217 })
218 })
219 })
220 }
221
222 #[pyo3(name = "delete_player")]
223 fn delete_player_py<'a>(
224 &self,
225 py: Python<'a>,
226 guild_id: super::model::PyGuildId,
227 ) -> PyResult<Bound<'a, PyAny>> {
228 let client = self.clone();
229
230 pyo3_async_runtimes::tokio::future_into_py(py, async move {
231 client.delete_player(guild_id).await?;
232
233 Ok(())
234 })
235 }
236
237 #[pyo3(name = "delete_all_player_contexts")]
238 fn delete_all_player_contexts_py<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
239 let client = self.clone();
240
241 pyo3_async_runtimes::tokio::future_into_py(py, async move {
242 client.delete_all_player_contexts().await?;
243
244 Ok(())
245 })
246 }
247
248 #[pyo3(name = "update_player")]
249 fn update_player_py<'a>(
250 &self,
251 py: Python<'a>,
252 guild_id: super::model::PyGuildId,
253 update_player: UpdatePlayer,
254 no_replace: bool,
255 ) -> PyResult<Bound<'a, PyAny>> {
256 let client = self.clone();
257
258 pyo3_async_runtimes::tokio::future_into_py(py, async move {
259 let player = client
260 .update_player(guild_id, &update_player, no_replace)
261 .await?;
262
263 Ok(player)
264 })
265 }
266
267 #[pyo3(name = "decode_track")]
268 fn decode_track_py<'a>(
269 &self,
270 py: Python<'a>,
271 guild_id: super::model::PyGuildId,
272 track: String,
273 ) -> PyResult<Bound<'a, PyAny>> {
274 let client = self.clone();
275
276 pyo3_async_runtimes::tokio::future_into_py(py, async move {
277 let track = client.decode_track(guild_id, &track).await?;
278
279 Ok(track)
280 })
281 }
282
283 #[pyo3(name = "decode_tracks")]
284 fn decode_tracks_py<'a>(
285 &self,
286 py: Python<'a>,
287 guild_id: super::model::PyGuildId,
288 tracks: Vec<String>,
289 ) -> PyResult<Bound<'a, PyAny>> {
290 let client = self.clone();
291
292 pyo3_async_runtimes::tokio::future_into_py(py, async move {
293 let tracks = client.decode_tracks(guild_id, &tracks).await?;
294
295 Ok(tracks)
296 })
297 }
298
299 #[pyo3(name = "request_version")]
300 fn request_version_py<'a>(
301 &self,
302 py: Python<'a>,
303 guild_id: super::model::PyGuildId,
304 ) -> PyResult<Bound<'a, PyAny>> {
305 let client = self.clone();
306
307 pyo3_async_runtimes::tokio::future_into_py(py, async move {
308 let version = client.request_version(guild_id).await?;
309
310 Ok(version)
311 })
312 }
313
314 #[pyo3(name = "request_info")]
315 fn request_info_py<'a>(
316 &self,
317 py: Python<'a>,
318 guild_id: super::model::PyGuildId,
319 ) -> PyResult<Bound<'a, PyAny>> {
320 let client = self.clone();
321
322 pyo3_async_runtimes::tokio::future_into_py(py, async move {
323 let info = client.request_info(guild_id).await?;
324
325 Ok(info)
326 })
327 }
328
329 #[pyo3(name = "request_stats")]
330 fn request_stats_py<'a>(
331 &self,
332 py: Python<'a>,
333 guild_id: super::model::PyGuildId,
334 ) -> PyResult<Bound<'a, PyAny>> {
335 let client = self.clone();
336
337 pyo3_async_runtimes::tokio::future_into_py(py, async move {
338 let stats = client.request_stats(guild_id).await?;
339
340 Ok(stats)
341 })
342 }
343
344 #[pyo3(name = "request_player")]
345 fn request_player_py<'a>(
346 &self,
347 py: Python<'a>,
348 guild_id: super::model::PyGuildId,
349 ) -> PyResult<Bound<'a, PyAny>> {
350 let client = self.clone();
351
352 pyo3_async_runtimes::tokio::future_into_py(py, async move {
353 let player = client.request_player(guild_id).await?;
354
355 Ok(player)
356 })
357 }
358
359 #[pyo3(name = "request_all_players")]
360 fn request_all_players_py<'a>(
361 &self,
362 py: Python<'a>,
363 guild_id: super::model::PyGuildId,
364 ) -> PyResult<Bound<'a, PyAny>> {
365 let client = self.clone();
366
367 pyo3_async_runtimes::tokio::future_into_py(py, async move {
368 let players = client.request_all_players(guild_id).await?;
369
370 Ok(players)
371 })
372 }
373
374 #[getter]
375 #[pyo3(name = "data")]
376 fn get_data_py<'a>(&self, py: Python<'a>) -> PyResult<PyObject> {
377 let client = self.clone();
378
379 let data = client.data::<RwLock<PyObject>>()?.read().clone_ref(py);
380
381 Ok(data)
382 }
383
384 #[setter]
385 #[pyo3(name = "data")]
386 fn set_data_py(&self, user_data: PyObject) -> PyResult<()> {
387 let client = self.clone();
388
389 *client.data::<RwLock<PyObject>>()?.write() = user_data;
390
391 Ok(())
392 }
393
394 #[pyo3(name = "handle_voice_server_update", signature = (guild_id, token, endpoint))]
395 fn handle_voice_server_update_py(
396 &self,
397 guild_id: super::model::PyGuildId,
398 token: String,
399 endpoint: Option<String>,
400 ) {
401 self.handle_voice_server_update(guild_id, token, endpoint);
402 }
403
404 #[pyo3(name = "handle_voice_state_update", signature = (guild_id, channel_id, user_id, session_id))]
405 fn handle_voice_state_update_py(
406 &self,
407 guild_id: super::model::PyGuildId,
408 channel_id: Option<super::model::PyChannelId>,
409 user_id: super::model::PyUserId,
410 session_id: String,
411 ) {
412 self.handle_voice_state_update(guild_id, channel_id, user_id, session_id);
413 }
414
415 #[pyo3(name = "get_connection_info")]
416 fn get_connection_info_py<'a>(
417 &self,
418 py: Python<'a>,
419 guild_id: super::model::PyGuildId,
420 timeout: u64,
421 ) -> PyResult<Bound<'a, PyAny>> {
422 let timeout = std::time::Duration::from_millis(timeout);
423 let client = self.clone();
424
425 pyo3_async_runtimes::tokio::future_into_py(py, async move {
426 let connection_info = client.get_connection_info(guild_id, timeout).await?;
427
428 Ok(connection_info)
429 })
430 }
431}