1use nostr_sdk::prelude::*;
4use std::sync::Arc;
5
6use crate::core::constants::*;
7use crate::core::error::{Error, Result};
8use crate::core::serializers;
9use crate::core::types::{EncryptionMode, JsonRpcMessage};
10use crate::core::validation;
11use crate::encryption;
12use crate::relay::RelayPoolTrait;
13
/// `tracing` target attached to every log record emitted from this module.
const LOG_TARGET: &str = "contextvm_sdk::transport::base";
15
16pub struct BaseTransport {
22 pub relay_pool: Arc<dyn RelayPoolTrait>,
24 pub encryption_mode: EncryptionMode,
26 pub is_connected: bool,
28}
29
30impl BaseTransport {
31 pub async fn connect(&mut self, relay_urls: &[String]) -> Result<()> {
33 if self.is_connected {
34 return Ok(());
35 }
36 self.relay_pool.connect(relay_urls).await?;
37 self.is_connected = true;
38 Ok(())
39 }
40
41 pub async fn disconnect(&mut self) -> Result<()> {
43 if !self.is_connected {
44 return Ok(());
45 }
46 self.relay_pool.disconnect().await?;
47 self.is_connected = false;
48 Ok(())
49 }
50
51 pub async fn get_public_key(&self) -> Result<PublicKey> {
53 self.relay_pool.public_key().await
54 }
55
56 pub async fn subscribe_for_pubkey(&self, pubkey: &PublicKey) -> Result<()> {
61 let p_tag = pubkey.to_hex();
62 let now = Timestamp::now();
63
64 let ephemeral_filter = Filter::new()
65 .kind(Kind::Custom(CTXVM_MESSAGES_KIND))
66 .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag.clone())
67 .since(now);
68
69 let gift_wrap_filter = Filter::new()
70 .kind(Kind::Custom(GIFT_WRAP_KIND))
71 .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag.clone())
72 .since(now);
73
74 let ephemeral_gift_wrap_filter = Filter::new()
75 .kind(Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND))
76 .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag)
77 .since(now);
78
79 self.relay_pool
80 .subscribe(vec![
81 ephemeral_filter,
82 gift_wrap_filter,
83 ephemeral_gift_wrap_filter,
84 ])
85 .await
86 }
87
88 pub fn convert_event_to_mcp(&self, content: &str) -> Option<JsonRpcMessage> {
90 validation::validate_and_parse(content)
91 }
92
93 pub async fn create_signed_event(
95 &self,
96 message: &JsonRpcMessage,
97 kind: u16,
98 tags: Vec<Tag>,
99 ) -> Result<Event> {
100 let builder = serializers::mcp_to_nostr_event(message, kind, tags)?;
101 self.relay_pool.sign(builder).await
102 }
103
104 pub async fn prepare_mcp_message(
109 &self,
110 message: &JsonRpcMessage,
111 recipient: &PublicKey,
112 kind: u16,
113 tags: Vec<Tag>,
114 is_encrypted: Option<bool>,
115 gift_wrap_kind: Option<u16>,
116 ) -> Result<(EventId, Event)> {
117 let should_encrypt = self.should_encrypt(kind, is_encrypted);
118
119 let event = self.create_signed_event(message, kind, tags).await?;
120 let signed_event_id = event.id;
121
122 if should_encrypt {
123 let event_json =
124 serde_json::to_string(&event).map_err(|e| Error::Encryption(e.to_string()))?;
125 let signer = self
126 .relay_pool
127 .signer()
128 .await
129 .map_err(|e| Error::Encryption(e.to_string()))?;
130 let selected_gift_wrap_kind = gift_wrap_kind.unwrap_or(GIFT_WRAP_KIND);
131 let gift_wrap_event = encryption::gift_wrap_single_layer_with_kind(
132 &signer,
133 recipient,
134 &event_json,
135 selected_gift_wrap_kind,
136 )
137 .await?;
138 tracing::debug!(
139 target: LOG_TARGET,
140 signed_event_id = %signed_event_id,
141 envelope_id = %gift_wrap_event.id,
142 gift_wrap_kind = selected_gift_wrap_kind,
143 "Prepared encrypted MCP message"
144 );
145 Ok((signed_event_id, gift_wrap_event))
146 } else {
147 tracing::debug!(
148 target: LOG_TARGET,
149 signed_event_id = %signed_event_id,
150 "Prepared unencrypted MCP message"
151 );
152 Ok((signed_event_id, event))
153 }
154 }
155
156 pub async fn send_mcp_message(
161 &self,
162 message: &JsonRpcMessage,
163 recipient: &PublicKey,
164 kind: u16,
165 tags: Vec<Tag>,
166 is_encrypted: Option<bool>,
167 gift_wrap_kind: Option<u16>,
168 ) -> Result<EventId> {
169 let should_encrypt = self.should_encrypt(kind, is_encrypted);
170
171 let event = self.create_signed_event(message, kind, tags).await?;
172 let signed_event_id = event.id;
173
174 if should_encrypt {
175 let event_json =
178 serde_json::to_string(&event).map_err(|e| Error::Encryption(e.to_string()))?;
179 let signer = self
180 .relay_pool
181 .signer()
182 .await
183 .map_err(|e| Error::Encryption(e.to_string()))?;
184 let selected_gift_wrap_kind = gift_wrap_kind.unwrap_or(GIFT_WRAP_KIND);
185 let gift_wrap_event = encryption::gift_wrap_single_layer_with_kind(
186 &signer,
187 recipient,
188 &event_json,
189 selected_gift_wrap_kind,
190 )
191 .await?;
192 self.relay_pool.publish_event(&gift_wrap_event).await?;
193 tracing::debug!(
194 target: LOG_TARGET,
195 signed_event_id = %signed_event_id,
196 envelope_id = %gift_wrap_event.id,
197 gift_wrap_kind = selected_gift_wrap_kind,
198 "Sent encrypted MCP message"
199 );
200 } else {
201 self.relay_pool.publish_event(&event).await?;
202 tracing::debug!(
203 target: LOG_TARGET,
204 signed_event_id = %signed_event_id,
205 "Sent unencrypted MCP message"
206 );
207 }
208
209 Ok(signed_event_id)
210 }
211
212 pub fn should_encrypt(&self, kind: u16, is_encrypted: Option<bool>) -> bool {
214 if UNENCRYPTED_KINDS.contains(&kind) {
216 return false;
217 }
218
219 match self.encryption_mode {
220 EncryptionMode::Disabled => false,
221 EncryptionMode::Required => true,
222 EncryptionMode::Optional => is_encrypted.unwrap_or(true),
223 }
224 }
225
226 pub fn create_recipient_tags(pubkey: &PublicKey) -> Vec<Tag> {
228 vec![Tag::public_key(*pubkey)]
229 }
230
231 pub fn create_response_tags(pubkey: &PublicKey, event_id: &EventId) -> Vec<Tag> {
233 vec![Tag::public_key(*pubkey), Tag::event(*event_id)]
234 }
235
236 pub fn compose_outbound_tags(
239 base_tags: &[Tag],
240 discovery_tags: &[Tag],
241 negotiation_tags: &[Tag],
242 ) -> Vec<Tag> {
243 let mut tags =
244 Vec::with_capacity(base_tags.len() + discovery_tags.len() + negotiation_tags.len());
245 tags.extend_from_slice(base_tags);
246 tags.extend_from_slice(discovery_tags);
247 tags.extend_from_slice(negotiation_tags);
248 tags
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::types::*;
    use crate::core::validation::{validate_message, validate_message_size};

    /// Every value a caller can pass as the `is_encrypted` preference.
    const PREFERENCES: [Option<bool>; 3] = [None, Some(true), Some(false)];

    /// Standalone copy of the decision logic in `BaseTransport::should_encrypt` that takes the
    /// mode directly; presumably kept so the policy can be checked without constructing a
    /// `BaseTransport`. The tests below exercise this copy, not the method itself.
    fn should_encrypt(mode: EncryptionMode, kind: u16, is_encrypted: Option<bool>) -> bool {
        !UNENCRYPTED_KINDS.contains(&kind)
            && match mode {
                EncryptionMode::Disabled => false,
                EncryptionMode::Required => true,
                EncryptionMode::Optional => is_encrypted.unwrap_or(true),
            }
    }

    /// Returns the first element (the tag name) of every tag, in order.
    fn tag_names(tags: &[Tag]) -> Vec<String> {
        tags.iter()
            .map(|tag| tag.clone().to_vec()[0].clone())
            .collect()
    }

    /// Builds a tag that has only a custom name and no values.
    fn make_custom_tag(name: &str) -> Tag {
        Tag::custom(TagKind::Custom(name.into()), Vec::<String>::new())
    }

    /// Parses a JSON fixture, panicking when the fixture itself is not valid JSON.
    fn json(content: &str) -> serde_json::Value {
        serde_json::from_str(content).unwrap()
    }

    #[test]
    fn test_should_encrypt_disabled_mode() {
        for preference in PREFERENCES {
            assert!(!should_encrypt(
                EncryptionMode::Disabled,
                CTXVM_MESSAGES_KIND,
                preference
            ));
        }
    }

    #[test]
    fn test_should_encrypt_required_mode() {
        for preference in PREFERENCES {
            assert!(should_encrypt(
                EncryptionMode::Required,
                CTXVM_MESSAGES_KIND,
                preference
            ));
        }
    }

    #[test]
    fn test_should_encrypt_optional_mode() {
        // Only an explicit opt-out disables encryption in optional mode.
        let expectations = [(None, true), (Some(true), true), (Some(false), false)];
        for (preference, expected) in expectations {
            let decision =
                should_encrypt(EncryptionMode::Optional, CTXVM_MESSAGES_KIND, preference);
            assert!(decision == expected);
        }
    }

    #[test]
    fn test_should_encrypt_announcement_kinds_never_encrypted() {
        for kind in UNENCRYPTED_KINDS.iter().copied() {
            assert!(!should_encrypt(EncryptionMode::Required, kind, Some(true)));
            assert!(!should_encrypt(EncryptionMode::Optional, kind, Some(true)));
            assert!(!should_encrypt(EncryptionMode::Disabled, kind, Some(true)));
        }
    }

    #[test]
    fn test_create_recipient_tags() {
        let pubkey = Keys::generate().public_key();

        let tags = BaseTransport::create_recipient_tags(&pubkey);
        assert_eq!(tags.len(), 1);

        let recipient = tags[0].clone().to_vec();
        assert_eq!(recipient[0], "p");
        assert_eq!(recipient[1], pubkey.to_hex());
    }

    #[test]
    fn test_create_response_tags() {
        let pubkey = Keys::generate().public_key();
        let event_id =
            EventId::from_hex("0000000000000000000000000000000000000000000000000000000000000001")
                .unwrap();

        let tags = BaseTransport::create_response_tags(&pubkey, &event_id);
        assert_eq!(tags.len(), 2);

        let recipient = tags[0].clone().to_vec();
        assert_eq!(recipient[0], "p");
        assert_eq!(recipient[1], pubkey.to_hex());

        let reference = tags[1].clone().to_vec();
        assert_eq!(reference[0], "e");
        assert_eq!(reference[1], event_id.to_hex());
    }

    #[test]
    fn test_convert_event_to_mcp_valid_request() {
        let value = json(r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#);
        let msg = validate_message(&value).unwrap();
        assert!(msg.is_request());
        assert_eq!(msg.method(), Some("tools/list"));
    }

    #[test]
    fn test_convert_event_to_mcp_valid_notification() {
        let value = json(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#);
        let msg = validate_message(&value).unwrap();
        assert!(msg.is_notification());
    }

    #[test]
    fn test_convert_event_to_mcp_valid_response() {
        let value = json(r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#);
        let msg = validate_message(&value).unwrap();
        assert!(msg.is_response());
    }

    #[test]
    fn test_convert_event_to_mcp_invalid_json() {
        assert!(serde_json::from_str::<serde_json::Value>("not json at all").is_err());
    }

    #[test]
    fn test_convert_event_to_mcp_invalid_jsonrpc_version() {
        let value = json(r#"{"jsonrpc":"1.0","id":1,"method":"test"}"#);
        assert!(validate_message(&value).is_none());
    }

    #[test]
    fn test_convert_event_to_mcp_oversized_message() {
        // One byte over the limit must be rejected.
        let oversized = "x".repeat(MAX_MESSAGE_SIZE + 1);
        assert!(!validate_message_size(&oversized));
    }

    #[test]
    fn compose_outbound_tags_ordering() {
        let base = vec![Tag::public_key(Keys::generate().public_key())];
        let discovery = vec![make_custom_tag("support_encryption")];
        let negotiation = vec![make_custom_tag("pmi")];

        let composed = BaseTransport::compose_outbound_tags(&base, &discovery, &negotiation);
        assert_eq!(tag_names(&composed), ["p", "support_encryption", "pmi"]);
    }

    #[test]
    fn compose_outbound_tags_empty_discovery() {
        let base = vec![Tag::public_key(Keys::generate().public_key())];
        let negotiation = vec![make_custom_tag("pmi")];

        let composed = BaseTransport::compose_outbound_tags(&base, &[], &negotiation);
        assert_eq!(tag_names(&composed), ["p", "pmi"]);
    }

    #[test]
    fn compose_outbound_tags_all_empty() {
        assert!(BaseTransport::compose_outbound_tags(&[], &[], &[]).is_empty());
    }

    #[test]
    fn compose_outbound_tags_preserves_all_elements() {
        let discovery = vec![
            make_custom_tag("support_encryption"),
            make_custom_tag("support_encryption_ephemeral"),
        ];

        let composed = BaseTransport::compose_outbound_tags(&[], &discovery, &[]);
        assert_eq!(
            tag_names(&composed),
            ["support_encryption", "support_encryption_ephemeral"]
        );
    }
}