lavalink_rs/python/client.rs

1use crate::model::events::Events;
2use crate::model::http::UpdatePlayer;
3use crate::model::player::ConnectionInfo;
4use crate::prelude::PlayerContext;
5
6use futures::future::BoxFuture;
7use parking_lot::RwLock;
8use pyo3::prelude::*;
9use pyo3::types::PyList;
10
11#[pymodule]
12pub fn client(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
13    m.add_class::<crate::client::LavalinkClient>()?;
14
15    Ok(())
16}
17
18fn raw_event(
19    _: crate::client::LavalinkClient,
20    session_id: String,
21    event: &serde_json::Value,
22) -> BoxFuture<()> {
23    Box::pin(async move {
24        debug!("{:?} -> {:?}", session_id, event);
25    })
26}
27
28#[pymethods]
29impl crate::client::LavalinkClient {
30    #[pyo3(name = "new")]
31    #[pyo3(signature = (events, nodes, strategy, user_data=None))]
32    #[staticmethod]
33    fn new_py<'a>(
34        py: Python<'a>,
35        events: PyObject,
36        nodes: Vec<crate::node::NodeBuilder>,
37        strategy: super::model::client::NodeDistributionStrategyPy,
38        user_data: Option<PyObject>,
39    ) -> PyResult<Bound<'a, PyAny>> {
40        let current_loop = pyo3_async_runtimes::get_running_loop(py)?;
41        let loop_ref = PyObject::from(current_loop);
42
43        let event_handler = crate::python::event::EventHandler {
44            inner: events,
45            current_loop: loop_ref,
46        };
47
48        let events = Events {
49            raw: Some(raw_event),
50            event_handler: Some(event_handler),
51            ..Default::default()
52        };
53
54        pyo3_async_runtimes::tokio::future_into_py_with_locals(
55            py,
56            pyo3_async_runtimes::tokio::get_current_locals(py)?,
57            async move {
58                if let Some(data) = user_data {
59                    Ok(crate::client::LavalinkClient::new_with_data(
60                        events,
61                        nodes,
62                        strategy.inner,
63                        std::sync::Arc::new(RwLock::new(data)),
64                    )
65                    .await)
66                } else {
67                    Ok(crate::client::LavalinkClient::new_with_data(
68                        events,
69                        nodes,
70                        strategy.inner,
71                        std::sync::Arc::new(RwLock::new(Python::with_gil(|py| py.None()))),
72                    )
73                    .await)
74                }
75            },
76        )
77    }
78
79    #[pyo3(name = "create_player_context")]
80    #[pyo3(signature = (guild_id, endpoint, token, session_id, user_data=None))]
81    fn create_player_context_py<'a>(
82        &self,
83        py: Python<'a>,
84        guild_id: super::model::PyGuildId,
85        endpoint: String,
86        token: String,
87        session_id: String,
88        user_data: Option<PyObject>,
89    ) -> PyResult<Bound<'a, PyAny>> {
90        let client = self.clone();
91
92        pyo3_async_runtimes::tokio::future_into_py_with_locals(
93            py,
94            pyo3_async_runtimes::tokio::get_current_locals(py)?,
95            async move {
96                if let Some(data) = user_data {
97                    Ok(client
98                        .create_player_context_with_data(
99                            guild_id,
100                            ConnectionInfo {
101                                endpoint,
102                                token,
103                                session_id,
104                            },
105                            std::sync::Arc::new(RwLock::new(data)),
106                        )
107                        .await?)
108                } else {
109                    Ok(client
110                        .create_player_context_with_data(
111                            guild_id,
112                            ConnectionInfo {
113                                endpoint,
114                                token,
115                                session_id,
116                            },
117                            std::sync::Arc::new(RwLock::new(Python::with_gil(|py| py.None()))),
118                        )
119                        .await?)
120                }
121            },
122        )
123    }
124
125    #[pyo3(name = "create_player")]
126    fn create_player_py<'a>(
127        &self,
128        py: Python<'a>,
129        guild_id: super::model::PyGuildId,
130        endpoint: String,
131        token: String,
132        session_id: String,
133    ) -> PyResult<Bound<'a, PyAny>> {
134        let client = self.clone();
135
136        pyo3_async_runtimes::tokio::future_into_py(py, async move {
137            let player = client
138                .create_player(
139                    guild_id,
140                    ConnectionInfo {
141                        endpoint,
142                        token,
143                        session_id,
144                    },
145                )
146                .await?;
147
148            Ok(Python::with_gil(|_py| player))
149        })
150    }
151
152    #[pyo3(name = "get_player_context")]
153    fn get_player_context_py<'a>(
154        &self,
155        guild_id: super::model::PyGuildId,
156    ) -> PyResult<Option<PlayerContext>> {
157        let player = self.get_player_context(guild_id);
158
159        Ok(player)
160    }
161
162    #[pyo3(name = "get_node_by_index")]
163    fn get_node_by_index_py(&self, idx: usize) -> Option<super::node::Node> {
164        self.get_node_by_index(idx)
165            .map(|x| super::node::Node { inner: x })
166    }
167
168    #[pyo3(name = "get_node_for_guild")]
169    pub fn get_node_for_guild_py<'a>(
170        &self,
171        py: Python<'a>,
172        guild_id: super::model::PyGuildId,
173    ) -> PyResult<Bound<'a, PyAny>> {
174        let client = self.clone();
175
176        pyo3_async_runtimes::tokio::future_into_py(py, async move {
177            let res = client.get_node_for_guild(guild_id).await;
178
179            Ok(Python::with_gil(|_py| super::node::Node { inner: res }))
180        })
181    }
182
183    #[pyo3(name = "load_tracks")]
184    fn load_tracks_py<'a>(
185        &self,
186        py: Python<'a>,
187        guild_id: super::model::PyGuildId,
188        identifier: String,
189    ) -> PyResult<Bound<'a, PyAny>> {
190        let client = self.clone();
191
192        pyo3_async_runtimes::tokio::future_into_py(py, async move {
193            let tracks = client.load_tracks(guild_id, &identifier).await?;
194
195            use crate::model::track::TrackLoadData::*;
196
197            Python::with_gil(|py| {
198                let track_data: Option<PyObject> = match tracks.data {
199                    Some(Track(x)) => Some(x.into_pyobject(py).unwrap().into_any()),
200                    Some(Playlist(x)) => Some(x.into_pyobject(py).unwrap().into_any()),
201                    Some(Search(x)) => {
202                        let l = PyList::empty(py);
203                        for i in x {
204                            l.append(i.into_pyobject(py).unwrap())?;
205                        }
206
207                        Some(l.into_pyobject(py).unwrap().into_any())
208                    }
209                    Some(Error(x)) => Some(x.into_pyobject(py).unwrap().into_any()),
210                    None => None,
211                }
212                .map(|x| x.into());
213
214                Ok(super::model::track::Track {
215                    load_type: tracks.load_type,
216                    data: track_data,
217                })
218            })
219        })
220    }
221
222    #[pyo3(name = "delete_player")]
223    fn delete_player_py<'a>(
224        &self,
225        py: Python<'a>,
226        guild_id: super::model::PyGuildId,
227    ) -> PyResult<Bound<'a, PyAny>> {
228        let client = self.clone();
229
230        pyo3_async_runtimes::tokio::future_into_py(py, async move {
231            client.delete_player(guild_id).await?;
232
233            Ok(())
234        })
235    }
236
237    #[pyo3(name = "delete_all_player_contexts")]
238    fn delete_all_player_contexts_py<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
239        let client = self.clone();
240
241        pyo3_async_runtimes::tokio::future_into_py(py, async move {
242            client.delete_all_player_contexts().await?;
243
244            Ok(())
245        })
246    }
247
248    #[pyo3(name = "update_player")]
249    fn update_player_py<'a>(
250        &self,
251        py: Python<'a>,
252        guild_id: super::model::PyGuildId,
253        update_player: UpdatePlayer,
254        no_replace: bool,
255    ) -> PyResult<Bound<'a, PyAny>> {
256        let client = self.clone();
257
258        pyo3_async_runtimes::tokio::future_into_py(py, async move {
259            let player = client
260                .update_player(guild_id, &update_player, no_replace)
261                .await?;
262
263            Ok(player)
264        })
265    }
266
267    #[pyo3(name = "decode_track")]
268    fn decode_track_py<'a>(
269        &self,
270        py: Python<'a>,
271        guild_id: super::model::PyGuildId,
272        track: String,
273    ) -> PyResult<Bound<'a, PyAny>> {
274        let client = self.clone();
275
276        pyo3_async_runtimes::tokio::future_into_py(py, async move {
277            let track = client.decode_track(guild_id, &track).await?;
278
279            Ok(track)
280        })
281    }
282
283    #[pyo3(name = "decode_tracks")]
284    fn decode_tracks_py<'a>(
285        &self,
286        py: Python<'a>,
287        guild_id: super::model::PyGuildId,
288        tracks: Vec<String>,
289    ) -> PyResult<Bound<'a, PyAny>> {
290        let client = self.clone();
291
292        pyo3_async_runtimes::tokio::future_into_py(py, async move {
293            let tracks = client.decode_tracks(guild_id, &tracks).await?;
294
295            Ok(tracks)
296        })
297    }
298
299    #[pyo3(name = "request_version")]
300    fn request_version_py<'a>(
301        &self,
302        py: Python<'a>,
303        guild_id: super::model::PyGuildId,
304    ) -> PyResult<Bound<'a, PyAny>> {
305        let client = self.clone();
306
307        pyo3_async_runtimes::tokio::future_into_py(py, async move {
308            let version = client.request_version(guild_id).await?;
309
310            Ok(version)
311        })
312    }
313
314    #[pyo3(name = "request_info")]
315    fn request_info_py<'a>(
316        &self,
317        py: Python<'a>,
318        guild_id: super::model::PyGuildId,
319    ) -> PyResult<Bound<'a, PyAny>> {
320        let client = self.clone();
321
322        pyo3_async_runtimes::tokio::future_into_py(py, async move {
323            let info = client.request_info(guild_id).await?;
324
325            Ok(info)
326        })
327    }
328
329    #[pyo3(name = "request_stats")]
330    fn request_stats_py<'a>(
331        &self,
332        py: Python<'a>,
333        guild_id: super::model::PyGuildId,
334    ) -> PyResult<Bound<'a, PyAny>> {
335        let client = self.clone();
336
337        pyo3_async_runtimes::tokio::future_into_py(py, async move {
338            let stats = client.request_stats(guild_id).await?;
339
340            Ok(stats)
341        })
342    }
343
344    #[pyo3(name = "request_player")]
345    fn request_player_py<'a>(
346        &self,
347        py: Python<'a>,
348        guild_id: super::model::PyGuildId,
349    ) -> PyResult<Bound<'a, PyAny>> {
350        let client = self.clone();
351
352        pyo3_async_runtimes::tokio::future_into_py(py, async move {
353            let player = client.request_player(guild_id).await?;
354
355            Ok(player)
356        })
357    }
358
359    #[pyo3(name = "request_all_players")]
360    fn request_all_players_py<'a>(
361        &self,
362        py: Python<'a>,
363        guild_id: super::model::PyGuildId,
364    ) -> PyResult<Bound<'a, PyAny>> {
365        let client = self.clone();
366
367        pyo3_async_runtimes::tokio::future_into_py(py, async move {
368            let players = client.request_all_players(guild_id).await?;
369
370            Ok(players)
371        })
372    }
373
374    #[getter]
375    #[pyo3(name = "data")]
376    fn get_data_py<'a>(&self, py: Python<'a>) -> PyResult<PyObject> {
377        let client = self.clone();
378
379        let data = client.data::<RwLock<PyObject>>()?.read().clone_ref(py);
380
381        Ok(data)
382    }
383
384    #[setter]
385    #[pyo3(name = "data")]
386    fn set_data_py(&self, user_data: PyObject) -> PyResult<()> {
387        let client = self.clone();
388
389        *client.data::<RwLock<PyObject>>()?.write() = user_data;
390
391        Ok(())
392    }
393
394    #[pyo3(name = "handle_voice_server_update", signature = (guild_id, token, endpoint))]
395    fn handle_voice_server_update_py(
396        &self,
397        guild_id: super::model::PyGuildId,
398        token: String,
399        endpoint: Option<String>,
400    ) {
401        self.handle_voice_server_update(guild_id, token, endpoint);
402    }
403
404    #[pyo3(name = "handle_voice_state_update", signature = (guild_id, channel_id, user_id, session_id))]
405    fn handle_voice_state_update_py(
406        &self,
407        guild_id: super::model::PyGuildId,
408        channel_id: Option<super::model::PyChannelId>,
409        user_id: super::model::PyUserId,
410        session_id: String,
411    ) {
412        self.handle_voice_state_update(guild_id, channel_id, user_id, session_id);
413    }
414
415    #[pyo3(name = "get_connection_info")]
416    fn get_connection_info_py<'a>(
417        &self,
418        py: Python<'a>,
419        guild_id: super::model::PyGuildId,
420        timeout: u64,
421    ) -> PyResult<Bound<'a, PyAny>> {
422        let timeout = std::time::Duration::from_millis(timeout);
423        let client = self.clone();
424
425        pyo3_async_runtimes::tokio::future_into_py(py, async move {
426            let connection_info = client.get_connection_info(guild_id, timeout).await?;
427
428            Ok(connection_info)
429        })
430    }
431}